add fixes (#2)

This commit is contained in:
Jerry Liu
2023-08-17 09:08:44 -07:00
committed by GitHub
parent 82c6d02530
commit 9f9f21bec5
6 changed files with 3015 additions and 20 deletions
+1 -1
View File
@@ -1,6 +1,6 @@
# Finetuning LLaMa + Text-to-SQL
This walkthrough shows you how to fine-tune LLaMa-7B on a Text-to-SQL dataset, and then use it for inference against
This walkthrough shows you how to fine-tune LLaMa 2 7B on a Text-to-SQL dataset, and then use it for inference against
any database of structured data using LlamaIndex.
+1 -1
View File
@@ -13,7 +13,7 @@ MODEL_PATH = "/model"
def download_models():
from transformers import LlamaForCausalLM, LlamaTokenizer
model_name = "openlm-research/open_llama_7b"
model_name = "openlm-research/open_llama_7b_v2"
model = LlamaForCausalLM.from_pretrained(model_name)
model.save_pretrained(MODEL_PATH)
+2 -6
View File
@@ -1,8 +1,4 @@
from typing import Optional
from modal import gpu, method, Retries
from modal.cls import ClsMixin
import json
from modal import Retries
from .common import (
output_vol,
@@ -102,7 +98,7 @@ def main(data_dir: str = "data_sql", model_dir: str = "data_sql", num_samples: i
num_samples=num_samples
)
for idx, (row_dict, completion) in enumerate(inputs_outputs_0):
print('************ Row {idx} ************')
print(f'************ Row {idx} ************')
print(f"Input {idx}: " + str(row_dict))
print(f"Output {idx} (finetuned model): " + str(completion))
print(f"Output {idx} (base model): " + str(input_outputs_1[idx][1]))
+2 -2
View File
@@ -247,5 +247,5 @@ def finetune(data_dir: str = "data_sql", model_dir: str = "data_sql"):
wandb_run_name=f"openllama-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
)
# Delete scraped data after fine-tuning
os.remove(data_path)
# # Delete scraped data after fine-tuning
# os.remove(data_path)
+14 -9
View File
@@ -74,14 +74,16 @@ def run_query(query: str, model_dir: str = "data_sql", use_finetuned_model: bool
)
response = query_engine.query(query)
print(
f'Model output: \n'
f'SQL Query: {str(response.metadata["sql_query"])}'
f"Response: {response.response}"
)
return response
def print_response(response):
print(
f'*****Model output*****\n'
f'SQL Query: {str(response.metadata["sql_query"])}\n'
f"Response: {response.response}\n"
)
@stub.local_entrypoint()
def main(query: str, sqlite_file_path: str, model_dir: str = "data_sql", use_finetuned_model: str = "True"):
"""Main function."""
@@ -91,8 +93,11 @@ def main(query: str, sqlite_file_path: str, model_dir: str = "data_sql", use_fin
if use_finetuned_model == "None":
# try both
run_query.call(query, model_dir=model_dir, use_finetuned_model=True)
run_query.call(query, model_dir=model_dir, use_finetuned_model=False)
response_0 = run_query.call(query, model_dir=model_dir, use_finetuned_model=True)
print_response(response_0)
response_1 = run_query.call(query, model_dir=model_dir, use_finetuned_model=False)
print_response(response_1)
else:
bool_toggle = use_finetuned_model == "True"
run_query.call(query, model_dir=model_dir, use_finetuned_model=bool_toggle)
response = run_query.call(query, model_dir=model_dir, use_finetuned_model=bool_toggle)
print_response(response)
+2995 -1
View File
File diff suppressed because it is too large Load Diff