mirror of
https://github.com/langchain-ai/recipe-chatbot.git
synced 2026-07-01 20:04:01 -04:00
Merge pull request #3 from jjshanks/bulk-worker-param
Allow max workers to be set via a parameter
This commit is contained in:
@@ -61,7 +61,7 @@ def process_query_sync(query_id: str, query: str) -> Tuple[str, str, str]:
|
||||
|
||||
|
||||
# Renamed and made sync
|
||||
def run_bulk_test(csv_path: Path) -> None:
|
||||
def run_bulk_test(csv_path: Path, num_workers: int = MAX_WORKERS) -> None:
|
||||
"""Main entry point for bulk testing (synchronous version)."""
|
||||
|
||||
with csv_path.open("r", newline="", encoding="utf-8") as csv_file:
|
||||
@@ -76,7 +76,7 @@ def run_bulk_test(csv_path: Path) -> None:
|
||||
|
||||
console = Console()
|
||||
results_data: List[Tuple[str, str, str]] = [] # Will store (id, query, response)
|
||||
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
future_to_data = {
|
||||
executor.submit(process_query_sync, item["id"], item["query"]):
|
||||
item for item in input_data
|
||||
@@ -130,5 +130,6 @@ def run_bulk_test(csv_path: Path) -> None:
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Bulk test the recipe chatbot")
|
||||
parser.add_argument("--csv", type=Path, default=DEFAULT_CSV, help="Path to CSV file containing queries (column name: 'query').")
|
||||
parser.add_argument("--workers", type=int, default=MAX_WORKERS, help=f"Number of worker threads (default: {MAX_WORKERS}).")
|
||||
args = parser.parse_args()
|
||||
run_bulk_test(args.csv)
|
||||
run_bulk_test(args.csv, args.workers)
|
||||
|
||||
Reference in New Issue
Block a user