mirror of
https://github.com/run-llama/modal_finetune_sql.git
synced 2026-07-01 21:44:58 -04:00
Merge pull request #1 from run-llama/jerry/add_finetune_tutorial
[wip] add sql finetuning tutorial with modal
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
.venv
|
||||
*__pycache__*
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Modal Labs
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -1 +1,34 @@
|
||||
# modal_finetune_sql
|
||||
# Finetuning LLaMa + Text-to-SQL
|
||||
|
||||
This walkthrough shows you how to fine-tune LLaMa-7B on a Text-to-SQL dataset, and then use it for inference against
|
||||
any database of structured data using LlamaIndex.
|
||||
|
||||
|
||||
This code is taken and adapted from the Modal `doppel-bot` repo: https://github.com/modal-labs/doppel-bot.
|
||||
|
||||
### Stack
|
||||
|
||||
- LlamaIndex
|
||||
- Modal
|
||||
- Hugging Face datasets
|
||||
- OpenLLaMa
|
||||
- Peft
|
||||
|
||||
|
||||
### Steps for running
|
||||
|
||||
Please see the notebook `tutorial.ipynb` for full instructions.
|
||||
|
||||
In the meantime you can run each step individually as below:
|
||||
|
||||
Loading data:
|
||||
`modal run src.load_data_sql`
|
||||
|
||||
Finetuning:
|
||||
`modal run --detach src.finetune_sql`
|
||||
|
||||
Inference:
|
||||
`modal run src.inference_sql_llamaindex::main --query "Which city has the highest population?" --sqlite-file-path "nbs/cities.db"`
|
||||
|
||||
(Optional) Downloading model weights:
|
||||
`modal run src.download_weights --output-dir out_model`
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
|
||||
[tool.ruff]
|
||||
ignore = [
|
||||
'E501',
|
||||
'E741',
|
||||
]
|
||||
select = [
|
||||
'E',
|
||||
'F',
|
||||
'W',
|
||||
]
|
||||
@@ -0,0 +1,4 @@
|
||||
modal-client==0.50.3044
|
||||
llama-index==0.8.2.post1
|
||||
datasets==2.14.4
|
||||
peft
|
||||
@@ -0,0 +1,82 @@
|
||||
from modal import Image, Stub, NetworkFileSystem, Dict
|
||||
import random
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
VOL_MOUNT_PATH = Path("/vol")
|
||||
|
||||
WANDB_PROJECT = "test-finetune-modal"
|
||||
|
||||
MODEL_PATH = "/model"
|
||||
|
||||
|
||||
def download_models():
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
model_name = "openlm-research/open_llama_7b"
|
||||
|
||||
model = LlamaForCausalLM.from_pretrained(model_name)
|
||||
model.save_pretrained(MODEL_PATH)
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(model_name)
|
||||
tokenizer.save_pretrained(MODEL_PATH)
|
||||
|
||||
|
||||
openllama_image = (
|
||||
Image.micromamba()
|
||||
.micromamba_install(
|
||||
"cudatoolkit=11.7",
|
||||
"cudnn=8.1.0",
|
||||
"cuda-nvcc",
|
||||
channels=["conda-forge", "nvidia"],
|
||||
)
|
||||
.apt_install("git")
|
||||
.pip_install(
|
||||
"accelerate==0.18.0",
|
||||
"bitsandbytes==0.37.0",
|
||||
"bitsandbytes-cuda117==0.26.0.post2",
|
||||
"datasets==2.10.1",
|
||||
"fire==0.5.0",
|
||||
"gradio==3.23.0",
|
||||
"peft @ git+https://github.com/huggingface/peft.git@e536616888d51b453ed354a6f1e243fecb02ea08",
|
||||
"transformers @ git+https://github.com/huggingface/transformers.git@a92e0ad2e20ef4ce28410b5e05c5d63a5a304e65",
|
||||
"torch==2.0.0",
|
||||
"torchvision==0.15.1",
|
||||
"sentencepiece==0.1.97",
|
||||
"llama-index==0.8.1",
|
||||
"sentence-transformers",
|
||||
)
|
||||
.run_function(download_models)
|
||||
.pip_install("wandb==0.15.0")
|
||||
)
|
||||
|
||||
stub = Stub(name="doppel-bot", image=openllama_image)
|
||||
stub.model_dict = Dict.new()
|
||||
stub.data_dict = Dict.new()
|
||||
|
||||
output_vol = NetworkFileSystem.new(cloud="gcp").persisted("doppelbot-vol")
|
||||
|
||||
|
||||
def generate_prompt_sql(input, context, output=""):
|
||||
return f"""You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.
|
||||
|
||||
You must output the SQL query that answers the question.
|
||||
|
||||
### Input:
|
||||
{input}
|
||||
|
||||
### Context:
|
||||
{context}
|
||||
|
||||
### Response:
|
||||
{output}"""
|
||||
|
||||
|
||||
def get_data_path(data_dir: str = "data_sql") -> Path:
|
||||
return VOL_MOUNT_PATH / data_dir / "data_sql.jsonl"
|
||||
|
||||
def get_model_path(data_dir: str = "data_sql", checkpoint: Optional[str] = None) -> Path:
|
||||
path = VOL_MOUNT_PATH / data_dir
|
||||
if checkpoint:
|
||||
path = path / checkpoint
|
||||
return path
|
||||
@@ -0,0 +1,2 @@
|
||||
{"input": "How many heads of the departments are older than 56 ?", "context": "CREATE TABLE head (age INTEGER)"}
|
||||
{"input": "List the name, born state and age of the heads of departments ordered by age.", "context": "CREATE TABLE head (name VARCHAR, born_state VARCHAR, age VARCHAR)"}
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Download weights."""
|
||||
|
||||
from .common import (
|
||||
stub, output_vol, VOL_MOUNT_PATH, get_model_path
|
||||
)
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
@stub.function(
|
||||
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
|
||||
cloud="gcp"
|
||||
)
|
||||
def load_model(model_dir: str = "data_sql"):
|
||||
"""Load model."""
|
||||
path = get_model_path(model_dir=model_dir)
|
||||
config_path = path / "adapter_config.json"
|
||||
model_path = path / "adapter_model.bin"
|
||||
|
||||
config_data = json.load(open(config_path))
|
||||
with open(model_path, "rb") as f:
|
||||
model_data = f.read()
|
||||
|
||||
print(f'loaded config, model data from {path}')
|
||||
|
||||
# read data, put this in `model_dict` on stub
|
||||
stub.model_dict["config"] = config_data
|
||||
stub.model_dict["model"] = model_data
|
||||
|
||||
@stub.local_entrypoint()
|
||||
def main(output_dir: str, model_dir: str = "data_sql"):
|
||||
# copy adapter_config.json and adapter_model.bin files into dict
|
||||
load_model.call(model_dir=model_dir)
|
||||
model_data = stub.model_dict["model"]
|
||||
config_data = stub.model_dict["config"]
|
||||
|
||||
print(f"Loaded model data, storing in {output_dir}")
|
||||
|
||||
# store locally
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
out_model_path = Path(output_dir) / "adapter_model.bin"
|
||||
out_config_path = Path(output_dir) / "adapter_config.json"
|
||||
|
||||
with open(out_model_path, "wb") as f:
|
||||
f.write(model_data)
|
||||
|
||||
with open(out_config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
print("Done!")
|
||||
|
||||
|
||||
|
||||
+109
@@ -0,0 +1,109 @@
|
||||
from typing import Optional
|
||||
|
||||
from modal import gpu, method, Retries
|
||||
from modal.cls import ClsMixin
|
||||
import json
|
||||
|
||||
from .common import (
|
||||
output_vol,
|
||||
stub,
|
||||
VOL_MOUNT_PATH,
|
||||
get_data_path,
|
||||
generate_prompt_sql
|
||||
)
|
||||
from .inference_utils import OpenLlamaLLM
|
||||
|
||||
|
||||
@stub.function(
|
||||
gpu="A100",
|
||||
retries=Retries(
|
||||
max_retries=3,
|
||||
initial_delay=5.0,
|
||||
backoff_coefficient=2.0,
|
||||
),
|
||||
timeout=60 * 60 * 2,
|
||||
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
|
||||
cloud="gcp",
|
||||
)
|
||||
def run_evals(
|
||||
sample_data,
|
||||
model_dir: str = "data_sql",
|
||||
use_finetuned_model: bool = True
|
||||
):
|
||||
llm = OpenLlamaLLM(
|
||||
model_dir=model_dir, max_new_tokens=256, use_finetuned_model=use_finetuned_model
|
||||
)
|
||||
inputs_outputs = []
|
||||
for row_dict in sample_data:
|
||||
prompt = generate_prompt_sql(row_dict["input"], row_dict["context"])
|
||||
completion = llm.complete(
|
||||
prompt,
|
||||
do_sample=True,
|
||||
temperature=0.3,
|
||||
top_p=0.85,
|
||||
top_k=40,
|
||||
num_beams=1,
|
||||
max_new_tokens=600,
|
||||
repetition_penalty=1.2,
|
||||
)
|
||||
inputs_outputs.append((row_dict, completion.text))
|
||||
return inputs_outputs
|
||||
|
||||
|
||||
@stub.function(
|
||||
gpu="A100",
|
||||
retries=Retries(
|
||||
max_retries=3,
|
||||
initial_delay=5.0,
|
||||
backoff_coefficient=2.0,
|
||||
),
|
||||
timeout=60 * 60 * 2,
|
||||
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
|
||||
cloud="gcp",
|
||||
)
|
||||
def run_evals_all(
|
||||
data_dir: str = "data_sql",
|
||||
model_dir: str = "data_sql",
|
||||
num_samples: int = 10,
|
||||
):
|
||||
# evaluate a sample from the same training set
|
||||
from datasets import load_dataset
|
||||
|
||||
data_path = get_data_path(data_dir).as_posix()
|
||||
data = load_dataset("json", data_files=data_path)
|
||||
|
||||
# load sample data
|
||||
sample_data = data["train"].shuffle().select(range(num_samples))
|
||||
|
||||
print('*** Running inference with finetuned model ***')
|
||||
inputs_outputs_0 = run_evals(
|
||||
sample_data=sample_data,
|
||||
model_dir=model_dir,
|
||||
use_finetuned_model=True
|
||||
)
|
||||
|
||||
print('*** Running inference with base model ***')
|
||||
input_outputs_1 = run_evals(
|
||||
sample_data=sample_data,
|
||||
model_dir=model_dir,
|
||||
use_finetuned_model=False
|
||||
)
|
||||
|
||||
return inputs_outputs_0, input_outputs_1
|
||||
|
||||
|
||||
|
||||
@stub.local_entrypoint()
|
||||
def main(data_dir: str = "data_sql", model_dir: str = "data_sql", num_samples: int = 10):
|
||||
"""Main function."""
|
||||
inputs_outputs_0, input_outputs_1 = run_evals_all.call(
|
||||
data_dir=data_dir,
|
||||
model_dir=model_dir,
|
||||
num_samples=num_samples
|
||||
)
|
||||
for idx, (row_dict, completion) in enumerate(inputs_outputs_0):
|
||||
print('************ Row {idx} ************')
|
||||
print(f"Input {idx}: " + str(row_dict))
|
||||
print(f"Output {idx} (finetuned model): " + str(completion))
|
||||
print(f"Output {idx} (base model): " + str(input_outputs_1[idx][1]))
|
||||
print('***********************************')
|
||||
@@ -0,0 +1,251 @@
|
||||
from modal import Secret
|
||||
from datetime import datetime
|
||||
import os
|
||||
from math import ceil
|
||||
|
||||
from .common import (
|
||||
MODEL_PATH,
|
||||
VOL_MOUNT_PATH,
|
||||
WANDB_PROJECT,
|
||||
output_vol,
|
||||
stub,
|
||||
get_data_path,
|
||||
get_model_path,
|
||||
generate_prompt_sql,
|
||||
)
|
||||
|
||||
|
||||
# This code is adapter from https://github.com/tloen/alpaca-lora/blob/65fb8225c09af81feb5edb1abb12560f02930703/finetune.py
|
||||
# with modifications mainly to expose more parameters to the user.
|
||||
def _train(
|
||||
# model/data params
|
||||
base_model: str,
|
||||
data,
|
||||
output_dir: str = "./lora-alpaca",
|
||||
eval_steps: int = 20,
|
||||
save_steps: int = 20,
|
||||
# training hyperparams
|
||||
batch_size: int = 128,
|
||||
micro_batch_size: int = 32,
|
||||
max_steps: int = 200,
|
||||
learning_rate: float = 3e-4,
|
||||
cutoff_len: int = 512,
|
||||
val_set_size: int = 100,
|
||||
# lora hyperparams
|
||||
lora_r: int = 16,
|
||||
lora_alpha: int = 16,
|
||||
lora_dropout: float = 0.05,
|
||||
lora_target_modules=[
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj",
|
||||
],
|
||||
# llm hyperparams
|
||||
train_on_inputs: bool = True, # if False, masks out inputs in loss
|
||||
add_eos_token: bool = True,
|
||||
group_by_length: bool = True, # faster, but produces an odd training loss curve
|
||||
# wandb params
|
||||
wandb_project: str = "",
|
||||
wandb_run_name: str = "",
|
||||
wandb_watch: str = "", # options: false | gradients | all
|
||||
wandb_log_model: str = "", # options: false | true
|
||||
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
|
||||
):
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from peft import (
|
||||
LoraConfig,
|
||||
get_peft_model,
|
||||
get_peft_model_state_dict,
|
||||
prepare_model_for_int8_training,
|
||||
set_peft_model_state_dict,
|
||||
)
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
gradient_accumulation_steps = batch_size // micro_batch_size
|
||||
|
||||
device_map = "auto"
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
ddp = world_size != 1
|
||||
if ddp:
|
||||
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
|
||||
gradient_accumulation_steps = gradient_accumulation_steps // world_size
|
||||
|
||||
# Check if parameter passed or if set within environ
|
||||
use_wandb = len(wandb_project) > 0 or ("WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
|
||||
# Only overwrite environ if wandb param passed
|
||||
if len(wandb_project) > 0:
|
||||
os.environ["WANDB_PROJECT"] = wandb_project
|
||||
if len(wandb_watch) > 0:
|
||||
os.environ["WANDB_WATCH"] = wandb_watch
|
||||
if len(wandb_log_model) > 0:
|
||||
os.environ["WANDB_LOG_MODEL"] = wandb_log_model
|
||||
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=True,
|
||||
torch_dtype=torch.float16,
|
||||
device_map=device_map,
|
||||
)
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(base_model, add_eos_token=True)
|
||||
|
||||
tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
|
||||
tokenizer.padding_side = "left" # Allow batched inference
|
||||
|
||||
def tokenize(prompt, add_eos_token=True):
|
||||
# there's probably a way to do this with the tokenizer settings
|
||||
# but again, gotta move fast
|
||||
result = tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=cutoff_len,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
if (
|
||||
result["input_ids"][-1] != tokenizer.eos_token_id
|
||||
and len(result["input_ids"]) < cutoff_len
|
||||
and add_eos_token
|
||||
):
|
||||
result["input_ids"].append(tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
|
||||
return result
|
||||
|
||||
def generate_and_tokenize_prompt(data_point):
|
||||
full_prompt = generate_prompt_sql(
|
||||
data_point["input"],
|
||||
data_point["context"],
|
||||
data_point["output"],
|
||||
)
|
||||
tokenized_full_prompt = tokenize(full_prompt)
|
||||
if not train_on_inputs:
|
||||
raise NotImplementedError("not implemented yet")
|
||||
return tokenized_full_prompt
|
||||
|
||||
model = prepare_model_for_int8_training(model)
|
||||
|
||||
config = LoraConfig(
|
||||
r=lora_r,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=lora_target_modules,
|
||||
lora_dropout=lora_dropout,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
model = get_peft_model(model, config)
|
||||
|
||||
if resume_from_checkpoint:
|
||||
# Check the available weights and load them
|
||||
checkpoint_name = os.path.join(resume_from_checkpoint, "pytorch_model.bin") # Full checkpoint
|
||||
if not os.path.exists(checkpoint_name):
|
||||
checkpoint_name = os.path.join(
|
||||
resume_from_checkpoint, "adapter_model.bin"
|
||||
) # only LoRA model - LoRA config above has to fit
|
||||
resume_from_checkpoint = False # So the trainer won't try loading its state
|
||||
# The two files above have a different name depending on how they were saved, but are actually the same.
|
||||
if os.path.exists(checkpoint_name):
|
||||
print(f"Restarting from {checkpoint_name}")
|
||||
adapters_weights = torch.load(checkpoint_name)
|
||||
set_peft_model_state_dict(model, adapters_weights)
|
||||
else:
|
||||
print(f"Checkpoint {checkpoint_name} not found")
|
||||
|
||||
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
|
||||
|
||||
if val_set_size > 0:
|
||||
train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
|
||||
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
|
||||
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
|
||||
else:
|
||||
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
|
||||
val_data = None
|
||||
|
||||
if not ddp and torch.cuda.device_count() > 1:
|
||||
# keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
|
||||
model.is_parallelizable = True
|
||||
model.model_parallel = True
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
model=model,
|
||||
train_dataset=train_data,
|
||||
eval_dataset=val_data,
|
||||
args=transformers.TrainingArguments(
|
||||
per_device_train_batch_size=micro_batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
warmup_steps=100,
|
||||
max_steps=max_steps,
|
||||
learning_rate=learning_rate,
|
||||
fp16=True,
|
||||
logging_steps=10,
|
||||
optim="adamw_torch",
|
||||
evaluation_strategy="steps" if val_set_size > 0 else "no",
|
||||
save_strategy="steps",
|
||||
eval_steps=eval_steps if val_set_size > 0 else None,
|
||||
save_steps=save_steps,
|
||||
output_dir=output_dir,
|
||||
# save_total_limit=3,
|
||||
load_best_model_at_end=False,
|
||||
ddp_find_unused_parameters=False if ddp else None,
|
||||
group_by_length=group_by_length,
|
||||
report_to="wandb" if use_wandb else "none",
|
||||
run_name=wandb_run_name if use_wandb else None,
|
||||
),
|
||||
data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
||||
),
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
old_state_dict = model.state_dict
|
||||
model.state_dict = (lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())).__get__(
|
||||
model, type(model)
|
||||
)
|
||||
|
||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||
model = torch.compile(model)
|
||||
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
model.save_pretrained(output_dir)
|
||||
|
||||
print("\n If there's a warning about missing keys above, please disregard :)")
|
||||
|
||||
|
||||
@stub.function(
|
||||
gpu="A100",
|
||||
# TODO: Modal should support optional secrets.
|
||||
secret=Secret.from_name("my-wandb-secret") if WANDB_PROJECT else None,
|
||||
timeout=60 * 60 * 2,
|
||||
network_file_systems={VOL_MOUNT_PATH: output_vol},
|
||||
cloud="oci",
|
||||
allow_cross_region_volumes=True,
|
||||
)
|
||||
def finetune(data_dir: str = "data_sql", model_dir: str = "data_sql"):
|
||||
from datasets import load_dataset
|
||||
|
||||
data_path = get_data_path(data_dir).as_posix()
|
||||
data = load_dataset("json", data_files=data_path)
|
||||
|
||||
num_samples = len(data["train"])
|
||||
val_set_size = ceil(0.1 * num_samples)
|
||||
print(f"Loaded {num_samples} samples. ")
|
||||
|
||||
_train(
|
||||
MODEL_PATH,
|
||||
data,
|
||||
val_set_size=val_set_size,
|
||||
output_dir=get_model_path(model_dir).as_posix(),
|
||||
wandb_project=WANDB_PROJECT,
|
||||
wandb_run_name=f"openllama-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
|
||||
)
|
||||
|
||||
# Delete scraped data after fine-tuning
|
||||
os.remove(data_path)
|
||||
@@ -0,0 +1,98 @@
|
||||
from typing import Optional
|
||||
|
||||
from modal import gpu, method, Retries
|
||||
from modal.cls import ClsMixin
|
||||
import json
|
||||
|
||||
from .common import (
|
||||
MODEL_PATH,
|
||||
output_vol,
|
||||
stub,
|
||||
VOL_MOUNT_PATH,
|
||||
get_model_path,
|
||||
generate_prompt_sql
|
||||
)
|
||||
from .inference_utils import OpenLlamaLLM
|
||||
|
||||
from llama_index.callbacks import CallbackManager
|
||||
from llama_index.llms import (
|
||||
CustomLLM,
|
||||
LLMMetadata,
|
||||
CompletionResponse,
|
||||
CompletionResponseGen,
|
||||
)
|
||||
from llama_index.llms.base import llm_completion_callback
|
||||
from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine
|
||||
from llama_index import SQLDatabase, ServiceContext, Prompt
|
||||
from typing import Any
|
||||
|
||||
|
||||
@stub.function(
|
||||
gpu="A100",
|
||||
retries=Retries(
|
||||
max_retries=3,
|
||||
initial_delay=5.0,
|
||||
backoff_coefficient=2.0,
|
||||
),
|
||||
timeout=60 * 60 * 2,
|
||||
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
|
||||
cloud="gcp",
|
||||
)
|
||||
def run_query(query: str, model_dir: str = "data_sql", use_finetuned_model: bool = True):
|
||||
"""Run query."""
|
||||
import pandas as pd
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
# define SQL database
|
||||
assert "sqlite_data" in stub.data_dict
|
||||
with open(VOL_MOUNT_PATH / "test_data.db", "wb") as fp:
|
||||
fp.write(stub.data_dict["sqlite_data"])
|
||||
|
||||
# define service context (containing custom LLM)
|
||||
print('setting up service context')
|
||||
# finetuned llama LLM
|
||||
num_output = 256
|
||||
llm = OpenLlamaLLM(
|
||||
model_dir=model_dir, max_new_tokens=num_output, use_finetuned_model=use_finetuned_model
|
||||
)
|
||||
service_context = ServiceContext.from_defaults(llm=llm)
|
||||
|
||||
sql_path = VOL_MOUNT_PATH / "test_data.db"
|
||||
engine = create_engine(f'sqlite:///{sql_path}', echo=True)
|
||||
sql_database = SQLDatabase(engine)
|
||||
|
||||
# define custom text-to-SQL prompt with generate prompt
|
||||
prompt_prefix = "Dialect: {dialect}\n\n"
|
||||
prompt_suffix = generate_prompt_sql("{query_str}", "{schema}", output="")
|
||||
sql_prompt = Prompt(prompt_prefix + prompt_suffix)
|
||||
|
||||
query_engine = NLSQLTableQueryEngine(
|
||||
sql_database,
|
||||
text_to_sql_prompt=sql_prompt,
|
||||
service_context=service_context,
|
||||
synthesize_response=False
|
||||
)
|
||||
response = query_engine.query(query)
|
||||
|
||||
print(
|
||||
f'Model output: \n'
|
||||
f'SQL Query: {str(response.metadata["sql_query"])}'
|
||||
f"Response: {response.response}"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@stub.local_entrypoint()
|
||||
def main(query: str, sqlite_file_path: str, model_dir: str = "data_sql", use_finetuned_model: str = "True"):
|
||||
"""Main function."""
|
||||
|
||||
fp = open(sqlite_file_path, "rb")
|
||||
stub.data_dict["sqlite_data"] = fp.read()
|
||||
|
||||
if use_finetuned_model == "None":
|
||||
# try both
|
||||
run_query.call(query, model_dir=model_dir, use_finetuned_model=True)
|
||||
run_query.call(query, model_dir=model_dir, use_finetuned_model=False)
|
||||
else:
|
||||
bool_toggle = use_finetuned_model == "True"
|
||||
run_query.call(query, model_dir=model_dir, use_finetuned_model=bool_toggle)
|
||||
@@ -0,0 +1,121 @@
|
||||
"""Get inference utils."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from modal import gpu
|
||||
from modal.cls import ClsMixin
|
||||
|
||||
from .common import (
|
||||
MODEL_PATH,
|
||||
output_vol,
|
||||
stub,
|
||||
VOL_MOUNT_PATH,
|
||||
get_model_path,
|
||||
)
|
||||
|
||||
from llama_index.callbacks import CallbackManager
|
||||
from llama_index.llms import (
|
||||
CustomLLM,
|
||||
LLMMetadata,
|
||||
CompletionResponse,
|
||||
CompletionResponseGen,
|
||||
)
|
||||
from llama_index.llms.base import llm_completion_callback
|
||||
from typing import Any
|
||||
|
||||
@stub.cls(
|
||||
gpu=gpu.A100(memory=20),
|
||||
network_file_systems={VOL_MOUNT_PATH: output_vol},
|
||||
)
|
||||
class OpenLlamaLLM(CustomLLM, ClsMixin):
|
||||
"""OpenLlamaLLM is a custom LLM that uses the OpenLlamaModel."""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str = "data_sql",
|
||||
max_new_tokens: int = 128,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
use_finetuned_model: bool = True,
|
||||
):
|
||||
super().__init__(callback_manager=callback_manager)
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
CHECKPOINT = get_model_path(model_dir)
|
||||
|
||||
load_8bit = False
|
||||
device = "cuda"
|
||||
|
||||
self.tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
|
||||
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
MODEL_PATH,
|
||||
load_in_8bit=load_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
if use_finetuned_model:
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
CHECKPOINT,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
if not load_8bit:
|
||||
model.half() # seems to fix bugs for some users.
|
||||
|
||||
model.eval()
|
||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||
model = torch.compile(model)
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
self._max_new_tokens = max_new_tokens
|
||||
|
||||
@property
|
||||
def metadata(self) -> LLMMetadata:
|
||||
"""Get LLM metadata."""
|
||||
return LLMMetadata(
|
||||
context_window=2048,
|
||||
num_output=self._max_new_tokens,
|
||||
model_name="finetuned_openllama_sql"
|
||||
)
|
||||
|
||||
@llm_completion_callback()
|
||||
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
|
||||
import torch
|
||||
from transformers import GenerationConfig
|
||||
# TODO: TO fill
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt")
|
||||
input_ids = inputs["input_ids"].to(self.device)
|
||||
# tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
|
||||
# print(tokens)
|
||||
generation_config = GenerationConfig(
|
||||
**kwargs,
|
||||
)
|
||||
with torch.no_grad():
|
||||
generation_output = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
generation_config=generation_config,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
max_new_tokens=self._max_new_tokens,
|
||||
)
|
||||
|
||||
s = generation_output.sequences[0]
|
||||
output = self.tokenizer.decode(s, skip_special_tokens=True)
|
||||
# NOTE: parsing response this way means that the model can mostly
|
||||
# only be used for text-to-SQL, not other purposes
|
||||
response_text = output.split("### Response:")[1].strip()
|
||||
return CompletionResponse(text=response_text)
|
||||
|
||||
@llm_completion_callback()
|
||||
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
import json
|
||||
|
||||
from modal import Retries
|
||||
from .common import (
|
||||
stub,
|
||||
VOL_MOUNT_PATH,
|
||||
output_vol,
|
||||
get_data_path
|
||||
)
|
||||
|
||||
@stub.function(
|
||||
retries=Retries(
|
||||
max_retries=3,
|
||||
initial_delay=5.0,
|
||||
backoff_coefficient=2.0,
|
||||
),
|
||||
timeout=60 * 60 * 2,
|
||||
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
|
||||
cloud="gcp",
|
||||
)
|
||||
def load_data_sql(data_dir: str = "data_sql"):
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("b-mc2/sql-create-context")
|
||||
|
||||
dataset_splits = {"train": dataset["train"]}
|
||||
out_path = get_data_path(data_dir)
|
||||
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for key, ds in dataset_splits.items():
|
||||
with open(out_path, "w") as f:
|
||||
for item in ds:
|
||||
newitem = {
|
||||
"input": item["question"],
|
||||
"context": item["context"],
|
||||
"output": item["answer"],
|
||||
}
|
||||
f.write(json.dumps(newitem) + "\n")
|
||||
|
||||
+6419
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user