Merge pull request #3 from jjshanks/bulk-worker-param

Allow max workers to be set via a parameter
This commit is contained in:
Shreya Shankar
2025-05-22 17:59:51 -07:00
committed by GitHub
+4 -3
View File
@@ -61,7 +61,7 @@ def process_query_sync(query_id: str, query: str) -> Tuple[str, str, str]:
# Renamed and made sync
def run_bulk_test(csv_path: Path) -> None:
def run_bulk_test(csv_path: Path, num_workers: int = MAX_WORKERS) -> None:
"""Main entry point for bulk testing (synchronous version)."""
with csv_path.open("r", newline="", encoding="utf-8") as csv_file:
@@ -76,7 +76,7 @@ def run_bulk_test(csv_path: Path) -> None:
console = Console()
results_data: List[Tuple[str, str, str]] = [] # Will store (id, query, response)
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
with ThreadPoolExecutor(max_workers=num_workers) as executor:
future_to_data = {
executor.submit(process_query_sync, item["id"], item["query"]):
item for item in input_data
@@ -130,5 +130,6 @@ def run_bulk_test(csv_path: Path) -> None:
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Bulk test the recipe chatbot")
parser.add_argument("--csv", type=Path, default=DEFAULT_CSV, help="Path to CSV file containing queries (column name: 'query').")
parser.add_argument("--workers", type=int, default=MAX_WORKERS, help=f"Number of worker threads (default: {MAX_WORKERS}).")
args = parser.parse_args()
run_bulk_test(args.csv)
run_bulk_test(args.csv, args.workers)