mirror of
https://github.com/onyx-dot-app/litellm.git
synced 2026-07-01 20:44:04 -04:00
feat: read from custom-llm-provider header (#15528)
This commit is contained in:
@@ -18,6 +18,7 @@ from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessin
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||
get_custom_llm_provider_from_request_query,
|
||||
get_custom_llm_provider_from_request_headers,
|
||||
)
|
||||
from litellm.proxy.openai_files_endpoints.common_utils import (
|
||||
_is_base64_encoded_unified_file_id,
|
||||
@@ -282,6 +283,7 @@ async def retrieve_batch(
|
||||
else:
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
or get_custom_llm_provider_from_request_headers(request=request)
|
||||
or get_custom_llm_provider_from_request_query(request=request)
|
||||
or "openai"
|
||||
)
|
||||
@@ -414,6 +416,7 @@ async def list_batches(
|
||||
else:
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
or get_custom_llm_provider_from_request_headers(request=request)
|
||||
or get_custom_llm_provider_from_request_query(request=request)
|
||||
or "openai"
|
||||
)
|
||||
|
||||
@@ -54,3 +54,11 @@ def get_custom_llm_provider_from_request_query(request: Request) -> Optional[str
|
||||
if "custom_llm_provider" in request.query_params:
|
||||
return request.query_params["custom_llm_provider"]
|
||||
return None
|
||||
|
||||
def get_custom_llm_provider_from_request_headers(request: Request) -> Optional[str]:
|
||||
"""
|
||||
Get the `custom_llm_provider` from the request header `custom-llm-provider`
|
||||
"""
|
||||
if "custom-llm-provider" in request.headers:
|
||||
return request.headers["custom-llm-provider"]
|
||||
return None
|
||||
|
||||
@@ -32,6 +32,7 @@ from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessin
|
||||
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||
get_custom_llm_provider_from_request_body,
|
||||
get_custom_llm_provider_from_request_query,
|
||||
get_custom_llm_provider_from_request_headers,
|
||||
)
|
||||
from litellm.proxy.utils import ProxyLogging, is_known_model
|
||||
from litellm.router import Router
|
||||
@@ -238,6 +239,7 @@ async def create_file(
|
||||
file_content = await file.read()
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
or get_custom_llm_provider_from_request_headers(request=request)
|
||||
or get_custom_llm_provider_from_request_query(request=request)
|
||||
or await get_custom_llm_provider_from_request_body(request=request)
|
||||
or "openai"
|
||||
@@ -427,6 +429,7 @@ async def get_file_content(
|
||||
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
or get_custom_llm_provider_from_request_headers(request=request)
|
||||
or get_custom_llm_provider_from_request_query(request=request)
|
||||
or await get_custom_llm_provider_from_request_body(request=request)
|
||||
or "openai"
|
||||
@@ -594,6 +597,7 @@ async def get_file(
|
||||
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
or get_custom_llm_provider_from_request_headers(request=request)
|
||||
or get_custom_llm_provider_from_request_query(request=request)
|
||||
or await get_custom_llm_provider_from_request_body(request=request)
|
||||
or "openai"
|
||||
@@ -743,6 +747,7 @@ async def delete_file(
|
||||
try:
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
or get_custom_llm_provider_from_request_headers(request=request)
|
||||
or get_custom_llm_provider_from_request_query(request=request)
|
||||
or await get_custom_llm_provider_from_request_body(request=request)
|
||||
or "openai"
|
||||
@@ -928,6 +933,7 @@ async def list_files(
|
||||
else:
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
or get_custom_llm_provider_from_request_headers(request=request)
|
||||
or get_custom_llm_provider_from_request_query(request=request)
|
||||
or await get_custom_llm_provider_from_request_body(request=request)
|
||||
or "openai"
|
||||
|
||||
@@ -72,7 +72,7 @@ def create_batch_oai_sdk(filepath: str, custom_llm_provider: str) -> str:
|
||||
batch_input_file = client.files.create(
|
||||
file=open(filepath, "rb"),
|
||||
purpose="batch",
|
||||
extra_body={"custom_llm_provider": custom_llm_provider},
|
||||
extra_headers={"custom-llm-provider": custom_llm_provider},
|
||||
)
|
||||
batch_input_file_id = batch_input_file.id
|
||||
|
||||
@@ -85,7 +85,7 @@ def create_batch_oai_sdk(filepath: str, custom_llm_provider: str) -> str:
|
||||
metadata={
|
||||
"description": filepath,
|
||||
},
|
||||
extra_body={"custom_llm_provider": custom_llm_provider},
|
||||
extra_headers={"custom-llm-provider": custom_llm_provider},
|
||||
)
|
||||
|
||||
print(f"Batch submitted. ID: {rq.id}")
|
||||
@@ -98,7 +98,7 @@ def await_batch_completion(batch_id: str, custom_llm_provider: str):
|
||||
|
||||
while tries < max_tries:
|
||||
batch = client.batches.retrieve(
|
||||
batch_id, extra_body={"custom_llm_provider": custom_llm_provider}
|
||||
batch_id, extra_headers={"custom-llm-provider": custom_llm_provider}
|
||||
)
|
||||
if batch.status == "completed":
|
||||
print(f"Batch {batch_id} completed.")
|
||||
@@ -117,11 +117,11 @@ def write_content_to_file(
|
||||
batch_id: str, output_path: str, custom_llm_provider: str
|
||||
) -> str:
|
||||
batch = client.batches.retrieve(
|
||||
batch_id=batch_id, extra_body={"custom_llm_provider": custom_llm_provider}
|
||||
batch_id=batch_id, extra_headers={"custom-llm-provider": custom_llm_provider}
|
||||
)
|
||||
content = client.files.content(
|
||||
file_id=batch.output_file_id,
|
||||
extra_body={"custom_llm_provider": custom_llm_provider},
|
||||
extra_headers={"custom-llm-provider": custom_llm_provider},
|
||||
)
|
||||
print("content from files.content", content.content)
|
||||
content.write_to_file(output_path)
|
||||
@@ -144,7 +144,7 @@ def read_jsonl(filepath: str):
|
||||
|
||||
def get_any_completed_batch_id_azure():
|
||||
print("AZURE getting any completed batch id")
|
||||
list_of_batches = client.batches.list(extra_body={"custom_llm_provider": "azure"})
|
||||
list_of_batches = client.batches.list(extra_headers={"custom-llm-provider": "azure"})
|
||||
print("list of batches", list_of_batches)
|
||||
for batch in list_of_batches:
|
||||
if batch.status == "completed":
|
||||
@@ -202,7 +202,7 @@ def test_vertex_batches_endpoint():
|
||||
file_obj = oai_client.files.create(
|
||||
file=open(file_path, "rb"),
|
||||
purpose="batch",
|
||||
extra_body={"custom_llm_provider": "vertex_ai"},
|
||||
extra_headers={"custom-llm-provider": "vertex_ai"},
|
||||
)
|
||||
print("Response from creating file=", file_obj)
|
||||
|
||||
@@ -215,7 +215,7 @@ def test_vertex_batches_endpoint():
|
||||
completion_window="24h",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=batch_input_file_id,
|
||||
extra_body={"custom_llm_provider": "vertex_ai"},
|
||||
extra_headers={"custom-llm-provider": "vertex_ai"},
|
||||
metadata={"key1": "value1", "key2": "value2"},
|
||||
)
|
||||
print("response from create batch", create_batch_response)
|
||||
|
||||
@@ -18,7 +18,7 @@ async def test_openai_fine_tuning():
|
||||
file_path = os.path.join(_current_dir, file_name)
|
||||
|
||||
response = await client.files.create(
|
||||
extra_body={"custom_llm_provider": "openai"},
|
||||
extra_headers={"custom-llm-provider": "openai"},
|
||||
file=open(file_path, "rb"),
|
||||
purpose="fine-tune",
|
||||
)
|
||||
@@ -32,7 +32,7 @@ async def test_openai_fine_tuning():
|
||||
ft_job = await client.fine_tuning.jobs.create(
|
||||
model="gpt-4o-mini-2024-07-18",
|
||||
training_file=response.id,
|
||||
extra_body={"custom_llm_provider": "openai"},
|
||||
extra_headers={"custom-llm-provider": "openai"},
|
||||
)
|
||||
|
||||
print("response from ft job={}".format(ft_job))
|
||||
@@ -42,7 +42,7 @@ async def test_openai_fine_tuning():
|
||||
|
||||
# list all fine tuning jobs
|
||||
list_ft_jobs = await client.fine_tuning.jobs.list(
|
||||
extra_query={"custom_llm_provider": "openai"}
|
||||
extra_headers={"custom-llm-provider": "openai"}
|
||||
)
|
||||
|
||||
print("list of ft jobs={}".format(list_ft_jobs))
|
||||
@@ -50,7 +50,7 @@ async def test_openai_fine_tuning():
|
||||
# cancel specific fine tuning job
|
||||
cancel_ft_job = await client.fine_tuning.jobs.cancel(
|
||||
fine_tuning_job_id=ft_job.id,
|
||||
extra_body={"custom_llm_provider": "openai"},
|
||||
extra_headers={"custom-llm-provider": "openai"},
|
||||
)
|
||||
|
||||
print("response from cancel ft job={}".format(cancel_ft_job))
|
||||
@@ -60,7 +60,7 @@ async def test_openai_fine_tuning():
|
||||
# delete OG file
|
||||
await client.files.delete(
|
||||
file_id=response.id,
|
||||
extra_body={"custom_llm_provider": "openai"},
|
||||
extra_headers={"custom-llm-provider": "openai"},
|
||||
)
|
||||
except openai.InternalServerError:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user