mirror of
https://github.com/run-llama/modal_finetune_sql.git
synced 2026-06-30 21:47:58 -04:00
add fixes (#2)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# Finetuning LLaMa + Text-to-SQL
|
||||
|
||||
This walkthrough shows you how to fine-tune LLaMa-7B on a Text-to-SQL dataset, and then use it for inference against
|
||||
This walkthrough shows you how to fine-tune LLaMa 2 7B on a Text-to-SQL dataset, and then use it for inference against
|
||||
any database of structured data using LlamaIndex.
|
||||
|
||||
|
||||
|
||||
+1
-1
@@ -13,7 +13,7 @@ MODEL_PATH = "/model"
|
||||
def download_models():
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
model_name = "openlm-research/open_llama_7b"
|
||||
model_name = "openlm-research/open_llama_7b_v2"
|
||||
|
||||
model = LlamaForCausalLM.from_pretrained(model_name)
|
||||
model.save_pretrained(MODEL_PATH)
|
||||
|
||||
+2
-6
@@ -1,8 +1,4 @@
|
||||
from typing import Optional
|
||||
|
||||
from modal import gpu, method, Retries
|
||||
from modal.cls import ClsMixin
|
||||
import json
|
||||
from modal import Retries
|
||||
|
||||
from .common import (
|
||||
output_vol,
|
||||
@@ -102,7 +98,7 @@ def main(data_dir: str = "data_sql", model_dir: str = "data_sql", num_samples: i
|
||||
num_samples=num_samples
|
||||
)
|
||||
for idx, (row_dict, completion) in enumerate(inputs_outputs_0):
|
||||
print('************ Row {idx} ************')
|
||||
print(f'************ Row {idx} ************')
|
||||
print(f"Input {idx}: " + str(row_dict))
|
||||
print(f"Output {idx} (finetuned model): " + str(completion))
|
||||
print(f"Output {idx} (base model): " + str(input_outputs_1[idx][1]))
|
||||
|
||||
+2
-2
@@ -247,5 +247,5 @@ def finetune(data_dir: str = "data_sql", model_dir: str = "data_sql"):
|
||||
wandb_run_name=f"openllama-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
|
||||
)
|
||||
|
||||
# Delete scraped data after fine-tuning
|
||||
os.remove(data_path)
|
||||
# # Delete scraped data after fine-tuning
|
||||
# os.remove(data_path)
|
||||
|
||||
@@ -74,14 +74,16 @@ def run_query(query: str, model_dir: str = "data_sql", use_finetuned_model: bool
|
||||
)
|
||||
response = query_engine.query(query)
|
||||
|
||||
print(
|
||||
f'Model output: \n'
|
||||
f'SQL Query: {str(response.metadata["sql_query"])}'
|
||||
f"Response: {response.response}"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def print_response(response):
|
||||
print(
|
||||
f'*****Model output*****\n'
|
||||
f'SQL Query: {str(response.metadata["sql_query"])}\n'
|
||||
f"Response: {response.response}\n"
|
||||
)
|
||||
|
||||
@stub.local_entrypoint()
|
||||
def main(query: str, sqlite_file_path: str, model_dir: str = "data_sql", use_finetuned_model: str = "True"):
|
||||
"""Main function."""
|
||||
@@ -91,8 +93,11 @@ def main(query: str, sqlite_file_path: str, model_dir: str = "data_sql", use_fin
|
||||
|
||||
if use_finetuned_model == "None":
|
||||
# try both
|
||||
run_query.call(query, model_dir=model_dir, use_finetuned_model=True)
|
||||
run_query.call(query, model_dir=model_dir, use_finetuned_model=False)
|
||||
response_0 = run_query.call(query, model_dir=model_dir, use_finetuned_model=True)
|
||||
print_response(response_0)
|
||||
response_1 = run_query.call(query, model_dir=model_dir, use_finetuned_model=False)
|
||||
print_response(response_1)
|
||||
else:
|
||||
bool_toggle = use_finetuned_model == "True"
|
||||
run_query.call(query, model_dir=model_dir, use_finetuned_model=bool_toggle)
|
||||
response = run_query.call(query, model_dir=model_dir, use_finetuned_model=bool_toggle)
|
||||
print_response(response)
|
||||
|
||||
+2995
-1
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user