feat: read from custom-llm-provider header (#15528)

This commit is contained in:
Timothée Lecomte
2025-10-19 07:04:53 +02:00
committed by GitHub
parent 4a74190c12
commit 3ef9b2015a
5 changed files with 30 additions and 13 deletions
@@ -18,6 +18,7 @@ from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessin
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.common_utils.openai_endpoint_utils import (
get_custom_llm_provider_from_request_query,
get_custom_llm_provider_from_request_headers,
)
from litellm.proxy.openai_files_endpoints.common_utils import (
_is_base64_encoded_unified_file_id,
@@ -282,6 +283,7 @@ async def retrieve_batch(
else:
custom_llm_provider = (
provider
or get_custom_llm_provider_from_request_headers(request=request)
or get_custom_llm_provider_from_request_query(request=request)
or "openai"
)
@@ -414,6 +416,7 @@ async def list_batches(
else:
custom_llm_provider = (
provider
or get_custom_llm_provider_from_request_headers(request=request)
or get_custom_llm_provider_from_request_query(request=request)
or "openai"
)
@@ -54,3 +54,11 @@ def get_custom_llm_provider_from_request_query(request: Request) -> Optional[str
if "custom_llm_provider" in request.query_params:
return request.query_params["custom_llm_provider"]
return None
def get_custom_llm_provider_from_request_headers(request: Request) -> Optional[str]:
"""
Get the `custom_llm_provider` from the request header `custom-llm-provider`
"""
if "custom-llm-provider" in request.headers:
return request.headers["custom-llm-provider"]
return None
@@ -32,6 +32,7 @@ from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessin
from litellm.proxy.common_utils.openai_endpoint_utils import (
get_custom_llm_provider_from_request_body,
get_custom_llm_provider_from_request_query,
get_custom_llm_provider_from_request_headers,
)
from litellm.proxy.utils import ProxyLogging, is_known_model
from litellm.router import Router
@@ -238,6 +239,7 @@ async def create_file(
file_content = await file.read()
custom_llm_provider = (
provider
or get_custom_llm_provider_from_request_headers(request=request)
or get_custom_llm_provider_from_request_query(request=request)
or await get_custom_llm_provider_from_request_body(request=request)
or "openai"
@@ -427,6 +429,7 @@ async def get_file_content(
custom_llm_provider = (
provider
or get_custom_llm_provider_from_request_headers(request=request)
or get_custom_llm_provider_from_request_query(request=request)
or await get_custom_llm_provider_from_request_body(request=request)
or "openai"
@@ -594,6 +597,7 @@ async def get_file(
custom_llm_provider = (
provider
or get_custom_llm_provider_from_request_headers(request=request)
or get_custom_llm_provider_from_request_query(request=request)
or await get_custom_llm_provider_from_request_body(request=request)
or "openai"
@@ -743,6 +747,7 @@ async def delete_file(
try:
custom_llm_provider = (
provider
or get_custom_llm_provider_from_request_headers(request=request)
or get_custom_llm_provider_from_request_query(request=request)
or await get_custom_llm_provider_from_request_body(request=request)
or "openai"
@@ -928,6 +933,7 @@ async def list_files(
else:
custom_llm_provider = (
provider
or get_custom_llm_provider_from_request_headers(request=request)
or get_custom_llm_provider_from_request_query(request=request)
or await get_custom_llm_provider_from_request_body(request=request)
or "openai"
@@ -72,7 +72,7 @@ def create_batch_oai_sdk(filepath: str, custom_llm_provider: str) -> str:
batch_input_file = client.files.create(
file=open(filepath, "rb"),
purpose="batch",
extra_body={"custom_llm_provider": custom_llm_provider},
extra_headers={"custom-llm-provider": custom_llm_provider},
)
batch_input_file_id = batch_input_file.id
@@ -85,7 +85,7 @@ def create_batch_oai_sdk(filepath: str, custom_llm_provider: str) -> str:
metadata={
"description": filepath,
},
extra_body={"custom_llm_provider": custom_llm_provider},
extra_headers={"custom-llm-provider": custom_llm_provider},
)
print(f"Batch submitted. ID: {rq.id}")
@@ -98,7 +98,7 @@ def await_batch_completion(batch_id: str, custom_llm_provider: str):
while tries < max_tries:
batch = client.batches.retrieve(
batch_id, extra_body={"custom_llm_provider": custom_llm_provider}
batch_id, extra_headers={"custom-llm-provider": custom_llm_provider}
)
if batch.status == "completed":
print(f"Batch {batch_id} completed.")
@@ -117,11 +117,11 @@ def write_content_to_file(
batch_id: str, output_path: str, custom_llm_provider: str
) -> str:
batch = client.batches.retrieve(
batch_id=batch_id, extra_body={"custom_llm_provider": custom_llm_provider}
batch_id=batch_id, extra_headers={"custom-llm-provider": custom_llm_provider}
)
content = client.files.content(
file_id=batch.output_file_id,
extra_body={"custom_llm_provider": custom_llm_provider},
extra_headers={"custom-llm-provider": custom_llm_provider},
)
print("content from files.content", content.content)
content.write_to_file(output_path)
@@ -144,7 +144,7 @@ def read_jsonl(filepath: str):
def get_any_completed_batch_id_azure():
print("AZURE getting any completed batch id")
list_of_batches = client.batches.list(extra_body={"custom_llm_provider": "azure"})
list_of_batches = client.batches.list(extra_headers={"custom-llm-provider": "azure"})
print("list of batches", list_of_batches)
for batch in list_of_batches:
if batch.status == "completed":
@@ -202,7 +202,7 @@ def test_vertex_batches_endpoint():
file_obj = oai_client.files.create(
file=open(file_path, "rb"),
purpose="batch",
extra_body={"custom_llm_provider": "vertex_ai"},
extra_headers={"custom-llm-provider": "vertex_ai"},
)
print("Response from creating file=", file_obj)
@@ -215,7 +215,7 @@ def test_vertex_batches_endpoint():
completion_window="24h",
endpoint="/v1/chat/completions",
input_file_id=batch_input_file_id,
extra_body={"custom_llm_provider": "vertex_ai"},
extra_headers={"custom-llm-provider": "vertex_ai"},
metadata={"key1": "value1", "key2": "value2"},
)
print("response from create batch", create_batch_response)
@@ -18,7 +18,7 @@ async def test_openai_fine_tuning():
file_path = os.path.join(_current_dir, file_name)
response = await client.files.create(
extra_body={"custom_llm_provider": "openai"},
extra_headers={"custom-llm-provider": "openai"},
file=open(file_path, "rb"),
purpose="fine-tune",
)
@@ -32,7 +32,7 @@ async def test_openai_fine_tuning():
ft_job = await client.fine_tuning.jobs.create(
model="gpt-4o-mini-2024-07-18",
training_file=response.id,
extra_body={"custom_llm_provider": "openai"},
extra_headers={"custom-llm-provider": "openai"},
)
print("response from ft job={}".format(ft_job))
@@ -42,7 +42,7 @@ async def test_openai_fine_tuning():
# list all fine tuning jobs
list_ft_jobs = await client.fine_tuning.jobs.list(
extra_query={"custom_llm_provider": "openai"}
extra_headers={"custom-llm-provider": "openai"}
)
print("list of ft jobs={}".format(list_ft_jobs))
@@ -50,7 +50,7 @@ async def test_openai_fine_tuning():
# cancel specific fine tuning job
cancel_ft_job = await client.fine_tuning.jobs.cancel(
fine_tuning_job_id=ft_job.id,
extra_body={"custom_llm_provider": "openai"},
extra_headers={"custom-llm-provider": "openai"},
)
print("response from cancel ft job={}".format(cancel_ft_job))
@@ -60,7 +60,7 @@ async def test_openai_fine_tuning():
# delete OG file
await client.files.delete(
file_id=response.id,
extra_body={"custom_llm_provider": "openai"},
extra_headers={"custom-llm-provider": "openai"},
)
except openai.InternalServerError:
pass