mirror of
https://github.com/onyx-dot-app/litellm.git
synced 2026-07-01 20:44:04 -04:00
[Feat] Tag Management - Add support for setting tag based budgets (#15433)
* feat: add LiteLLM_TagTable * fix: use new table for tag management * fix - allow setting budgets for tags * working tag creation * fix schema.prisma * add tag info * ui fixes * ui fix tag info * TAG_CACHE_IN_MEMORY_TTL_SECONDS * add Litellm_EntityType * fix get_aggregated_db_spend_update_transactions * fix: _update_entity_spend_in_db * fix _tag_max_budget_check * add tag budget check * add tag_list_transactions * test_get_tag_objects_batch * test_update_tag_db_without_prisma_client * fix get_tags_from_request_body * get_tags_from_request_body * fix get_tags_from_request_body * fix spend tracking utils * get_tags_from_request_body * test_get_tags_from_request_body_with_metadata_tags * feat: add _update_tag_cache spend tracking * fix _PROXY_track_cost_callback * test_tag_cache_update_multiple_tags * fix tag info * docs fix * docs tag budgets * doc fix * docs fix * fix tag budget * docs tag budgets * docs fix * ruff fix
This commit is contained in:
@@ -0,0 +1,277 @@
|
||||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Setting Tag Budgets
|
||||
|
||||
Track spend and set budgets for your API requests using tags. Tags allow you to categorize and monitor costs across different cost centers, projects, and departments.
|
||||
|
||||
## Pre-Requisites
|
||||
|
||||
- You must set up a Postgres database (e.g. Supabase, Neon, etc.)
|
||||
|
||||
## What are Tags?
|
||||
|
||||
Tags are labels you can attach to your LLM requests to track and limit spending by category.
|
||||
|
||||
**Common Use Cases:**
|
||||
- **Cost Center Tracking**: Allocate LLM costs to specific departments or business units (e.g., "engineering", "marketing", "customer-support")
|
||||
- **Project-based Budgeting**: Set budgets for different projects or initiatives (e.g., "project-alpha", "chatbot-v2")
|
||||
- **Customer Attribution**: Track spend per customer or client (e.g., "customer-acme", "customer-techcorp")
|
||||
- **Feature Monitoring**: Monitor costs for specific features (e.g., "feature-chat", "feature-summarization")
|
||||
|
||||
Tags are added to each request in the `metadata` field to track and enforce budget limits.
|
||||
|
||||
## Setting Tag Budgets
|
||||
|
||||
### 1. Create a tag with budget
|
||||
|
||||
Create a tag to represent a cost center, project, or any budget category. Set `max_budget` ($ value allowed) and `budget_duration` (how frequently the budget resets).
|
||||
|
||||
**Example:** Create a tag for your Engineering department with a monthly $500 budget
|
||||
|
||||
#### API
|
||||
|
||||
Create a new tag and set `max_budget` and `budget_duration`
|
||||
|
||||
```shell
|
||||
curl -X POST 'http://0.0.0.0:4000/tag/new' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"name": "engineering",
|
||||
"description": "Engineering department cost center",
|
||||
"max_budget": 500.0,
|
||||
"budget_duration": "30d"
|
||||
}'
|
||||
```
|
||||
|
||||
**Request Body Parameters:**
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
|-----------|------|----------|-------------|
|
||||
| `name` | string | Yes | Unique name for the tag (e.g., cost center name) |
|
||||
| `description` | string | No | Description of what this tag tracks |
|
||||
| `models` | list[string] | No | Restrict tag to specific models |
|
||||
| `max_budget` | float | No | Maximum budget in USD |
|
||||
| `budget_duration` | string | No | How often budget resets (e.g., "30d", "1d") |
|
||||
| `soft_budget` | float | No | Soft budget limit for warnings |
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "engineering",
|
||||
"description": "Engineering department cost center",
|
||||
"max_budget": 500.0,
|
||||
"budget_duration": "30d",
|
||||
"budget_reset_at": "2025-11-10T00:00:00Z",
|
||||
"created_at": "2025-10-11T00:00:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
#### LiteLLM Admin UI
|
||||
|
||||
Navigate to the **Tag Management** page and click **Create New Tag**. Fill in the tag details and set your budget:
|
||||
|
||||
<Image
|
||||
img={require('../../img/tag_budget1.png')}
|
||||
style={{width: '80%', display: 'block', margin: '0'}}
|
||||
/>
|
||||
|
||||
<br />
|
||||
|
||||
|
||||
**Possible values for `budget_duration`:**
|
||||
|
||||
| `budget_duration` | When Budget will reset |
|
||||
| --- | --- |
|
||||
| `budget_duration="1s"` | every 1 second |
|
||||
| `budget_duration="1m"` | every 1 minute |
|
||||
| `budget_duration="1h"` | every 1 hour |
|
||||
| `budget_duration="1d"` | every 1 day |
|
||||
| `budget_duration="7d"` | every 1 week |
|
||||
| `budget_duration="30d"` | every 1 month |
|
||||
|
||||
### 2. Use the tag in your requests
|
||||
|
||||
Add tags to your API requests in the `metadata` field:
|
||||
|
||||
:::info Tags Budgets on API Keys
|
||||
|
||||
Currently, tag budget enforcement is only supported per request. If you'd like to set tags on API keys so all requests automatically inherit the tags budgets, please [create a feature request on GitHub](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=enhancement&projects=&template=feature_request.yml&title=%5BFeat%5D%3A).
|
||||
|
||||
:::
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="openai" label="OpenAI SDK">
|
||||
|
||||
```python
|
||||
import openai
|
||||
|
||||
client = openai.OpenAI(
|
||||
api_key="sk-1234", # Your LiteLLM proxy key
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
extra_body={
|
||||
"metadata": {
|
||||
"tags": ["engineering"]
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="cURL">
|
||||
|
||||
```shell
|
||||
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"metadata": {
|
||||
"tags": ["engineering"]
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
### 3. Test It
|
||||
|
||||
Make requests until the budget is exceeded:
|
||||
|
||||
```shell
|
||||
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"metadata": {
|
||||
"tags": ["engineering"]
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
**When budget is exceeded, you'll see:**
|
||||
|
||||
```json
|
||||
{
|
||||
"error": {
|
||||
"message": "Budget has been exceeded! Tag=engineering Current cost: 505.50, Max budget: 500.0",
|
||||
"type": "budget_exceeded",
|
||||
"param": null,
|
||||
"code": "400"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Managing Tags
|
||||
|
||||
### View Tag Information
|
||||
|
||||
Get information about specific tags:
|
||||
|
||||
```shell
|
||||
curl -X POST 'http://0.0.0.0:4000/tag/info' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"names": ["engineering", "marketing"]
|
||||
}'
|
||||
```
|
||||
|
||||
**Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"engineering": {
|
||||
"name": "engineering",
|
||||
"description": "Engineering department cost center",
|
||||
"spend": 245.50,
|
||||
"max_budget": 500.0,
|
||||
"budget_duration": "30d",
|
||||
"budget_reset_at": "2025-11-10T00:00:00Z",
|
||||
"created_at": "2025-10-11T00:00:00Z",
|
||||
"updated_at": "2025-10-11T12:30:00Z"
|
||||
},
|
||||
"marketing": {
|
||||
"name": "marketing",
|
||||
"description": "Marketing department cost center",
|
||||
"spend": 89.20,
|
||||
"max_budget": 300.0,
|
||||
"budget_duration": "30d",
|
||||
"budget_reset_at": "2025-11-10T00:00:00Z",
|
||||
"created_at": "2025-10-11T00:00:00Z",
|
||||
"updated_at": "2025-10-11T12:30:00Z"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Update Tag Budget
|
||||
|
||||
Update an existing tag's budget:
|
||||
|
||||
```shell
|
||||
curl -X POST 'http://0.0.0.0:4000/tag/update' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"name": "engineering",
|
||||
"max_budget": 750.0,
|
||||
"budget_duration": "30d"
|
||||
}'
|
||||
```
|
||||
|
||||
### Delete Tag
|
||||
|
||||
```shell
|
||||
curl -X POST 'http://0.0.0.0:4000/tag/delete' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"name": "engineering"
|
||||
}'
|
||||
```
|
||||
|
||||
## Multiple Tags per Request
|
||||
|
||||
You can apply multiple tags to a single request to track costs across different dimensions simultaneously. For example, track both the cost center and the specific project:
|
||||
|
||||
```python
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
extra_body={
|
||||
"metadata": {
|
||||
"tags": ["engineering", "project-alpha", "customer-acme"]
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
```shell
|
||||
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"metadata": {
|
||||
"tags": ["engineering", "project-alpha", "customer-acme"]
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
**Budget Enforcement:** If any tag exceeds its budget, the request will be rejected.
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 305 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 416 KiB |
@@ -189,12 +189,13 @@ const sidebars = {
|
||||
type: "category",
|
||||
label: "Budgets + Rate Limits",
|
||||
items: [
|
||||
"proxy/users",
|
||||
"proxy/team_budgets",
|
||||
"proxy/tag_budgets",
|
||||
"proxy/customers",
|
||||
"proxy/dynamic_rate_limit",
|
||||
"proxy/rate_limit_tiers",
|
||||
"proxy/team_budgets",
|
||||
"proxy/temporary_budget_increase",
|
||||
"proxy/users"
|
||||
],
|
||||
},
|
||||
"proxy/caching",
|
||||
|
||||
@@ -25,6 +25,7 @@ model LiteLLM_BudgetTable {
|
||||
organization LiteLLM_OrganizationTable[] // multiple orgs can have the same budget
|
||||
keys LiteLLM_VerificationToken[] // multiple keys can have the same budget
|
||||
end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget
|
||||
tags LiteLLM_TagTable[] // multiple tags can have the same budget
|
||||
team_membership LiteLLM_TeamMembership[] // budgets of Users within a Team
|
||||
organization_membership LiteLLM_OrganizationMembership[] // budgets of Users within a Organization
|
||||
}
|
||||
@@ -245,6 +246,20 @@ model LiteLLM_EndUserTable {
|
||||
blocked Boolean @default(false)
|
||||
}
|
||||
|
||||
// Track tags with budgets and spend
|
||||
model LiteLLM_TagTable {
|
||||
tag_name String @id
|
||||
description String?
|
||||
models String[]
|
||||
model_info Json? // maps model_id to model_name
|
||||
spend Float @default(0.0)
|
||||
budget_id String?
|
||||
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
|
||||
created_at DateTime @default(now()) @map("created_at")
|
||||
created_by String?
|
||||
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||
}
|
||||
|
||||
// store proxy config.yaml
|
||||
model LiteLLM_Config {
|
||||
param_name String @id
|
||||
|
||||
@@ -185,6 +185,7 @@ class Litellm_EntityType(enum.Enum):
|
||||
TEAM = "team"
|
||||
TEAM_MEMBER = "team_member"
|
||||
ORGANIZATION = "organization"
|
||||
TAG = "tag"
|
||||
|
||||
# global proxy level entity
|
||||
PROXY = "proxy"
|
||||
@@ -2143,6 +2144,30 @@ class LiteLLM_EndUserTable(LiteLLMPydanticObjectBase):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LiteLLM_TagTable(LiteLLMPydanticObjectBase):
|
||||
tag_name: str
|
||||
description: Optional[str] = None
|
||||
models: List[str] = []
|
||||
model_info: Optional[dict] = None
|
||||
spend: float = 0.0
|
||||
budget_id: Optional[str] = None
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
created_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_model_info(cls, values):
|
||||
if values.get("spend") is None:
|
||||
values.update({"spend": 0.0})
|
||||
if values.get("models") is None:
|
||||
values.update({"models": []})
|
||||
return values
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LiteLLM_SpendLogs(LiteLLMPydanticObjectBase):
|
||||
request_id: str
|
||||
api_key: str
|
||||
@@ -3419,6 +3444,7 @@ class DBSpendUpdateTransactions(TypedDict):
|
||||
team_list_transactions: Optional[Dict[str, float]]
|
||||
team_member_list_transactions: Optional[Dict[str, float]]
|
||||
org_list_transactions: Optional[Dict[str, float]]
|
||||
tag_list_transactions: Optional[Dict[str, float]]
|
||||
|
||||
|
||||
class SpendUpdateQueueItem(TypedDict, total=False):
|
||||
|
||||
@@ -35,6 +35,7 @@ from litellm.proxy._types import (
|
||||
LiteLLM_ObjectPermissionTable,
|
||||
LiteLLM_OrganizationMembershipTable,
|
||||
LiteLLM_OrganizationTable,
|
||||
LiteLLM_TagTable,
|
||||
LiteLLM_TeamMembership,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_TeamTableCachedObj,
|
||||
@@ -99,6 +100,8 @@ async def common_checks(
|
||||
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
|
||||
11. [OPTIONAL] Vector store checks - is the object allowed to access the vector store
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client, user_api_key_cache
|
||||
|
||||
_model: Optional[Union[str, List[str]]] = get_model_from_request(
|
||||
request_body, route
|
||||
)
|
||||
@@ -139,6 +142,14 @@ async def common_checks(
|
||||
valid_token=valid_token,
|
||||
)
|
||||
|
||||
await _tag_max_budget_check(
|
||||
request_body=request_body,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
valid_token=valid_token,
|
||||
)
|
||||
|
||||
# 4. If user is in budget
|
||||
## 4.1 check personal budget, if personal key
|
||||
if (
|
||||
@@ -499,6 +510,115 @@ async def get_end_user_object(
|
||||
return None
|
||||
|
||||
|
||||
@log_db_metrics
|
||||
async def get_tag_objects_batch(
|
||||
tag_names: List[str],
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||
) -> Dict[str, LiteLLM_TagTable]:
|
||||
"""
|
||||
Batch fetch multiple tag objects from cache and db.
|
||||
|
||||
Optimizes for latency by:
|
||||
1. Fetching all cached tags in parallel
|
||||
2. Batch fetching uncached tags in one DB query
|
||||
|
||||
Args:
|
||||
tag_names: List of tag names to fetch
|
||||
prisma_client: Prisma database client
|
||||
user_api_key_cache: Cache for storing tag objects
|
||||
parent_otel_span: Optional OpenTelemetry span for tracing
|
||||
proxy_logging_obj: Optional proxy logging object
|
||||
|
||||
Returns:
|
||||
Dictionary mapping tag_name to LiteLLM_TagTable object
|
||||
"""
|
||||
if prisma_client is None:
|
||||
return {}
|
||||
|
||||
if not tag_names:
|
||||
return {}
|
||||
|
||||
tag_objects = {}
|
||||
uncached_tags = []
|
||||
|
||||
# Try to get all tags from cache first
|
||||
for tag_name in tag_names:
|
||||
cache_key = f"tag:{tag_name}"
|
||||
cached_tag = await user_api_key_cache.async_get_cache(key=cache_key)
|
||||
if cached_tag is not None:
|
||||
if isinstance(cached_tag, dict):
|
||||
tag_objects[tag_name] = LiteLLM_TagTable(**cached_tag)
|
||||
else:
|
||||
tag_objects[tag_name] = cached_tag
|
||||
else:
|
||||
uncached_tags.append(tag_name)
|
||||
|
||||
# Batch fetch uncached tags from DB in one query
|
||||
if uncached_tags:
|
||||
try:
|
||||
db_tags = await prisma_client.db.litellm_tagtable.find_many(
|
||||
where={"tag_name": {"in": uncached_tags}},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
|
||||
# Cache and add to tag_objects
|
||||
for db_tag in db_tags:
|
||||
tag_name = db_tag.tag_name
|
||||
cache_key = f"tag:{tag_name}"
|
||||
# Cache with default TTL (same as end_user objects)
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=cache_key, value=db_tag.dict()
|
||||
)
|
||||
tag_objects[tag_name] = LiteLLM_TagTable(**db_tag.dict())
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Error batch fetching tags from database: {e}"
|
||||
)
|
||||
|
||||
return tag_objects
|
||||
|
||||
|
||||
@log_db_metrics
|
||||
async def get_tag_object(
|
||||
tag_name: Optional[str],
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||
) -> Optional[LiteLLM_TagTable]:
|
||||
"""
|
||||
Returns tag object from cache or db.
|
||||
|
||||
Uses default cache TTL (same as end_user objects) to avoid drift.
|
||||
|
||||
Args:
|
||||
tag_name: Name of the tag to fetch
|
||||
prisma_client: Prisma database client
|
||||
user_api_key_cache: Cache for storing tag objects
|
||||
parent_otel_span: Optional OpenTelemetry span for tracing
|
||||
proxy_logging_obj: Optional proxy logging object
|
||||
|
||||
Returns:
|
||||
LiteLLM_TagTable object if found, None otherwise
|
||||
"""
|
||||
if prisma_client is None or tag_name is None:
|
||||
return None
|
||||
|
||||
# Use batch helper for consistency
|
||||
tag_objects = await get_tag_objects_batch(
|
||||
tag_names=[tag_name],
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
return tag_objects.get(tag_name)
|
||||
|
||||
|
||||
@log_db_metrics
|
||||
async def get_team_membership(
|
||||
user_id: str,
|
||||
@@ -1642,6 +1762,60 @@ async def _team_max_budget_check(
|
||||
)
|
||||
|
||||
|
||||
async def _tag_max_budget_check(
|
||||
request_body: dict,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
valid_token: Optional[UserAPIKeyAuth],
|
||||
):
|
||||
"""
|
||||
Check if any tags in the request are over their max budget.
|
||||
|
||||
Raises:
|
||||
BudgetExceededError if any tag is over its max budget.
|
||||
Triggers a budget alert if any tag is over its max budget.
|
||||
"""
|
||||
from litellm.proxy.common_utils.http_parsing_utils import (
|
||||
get_tags_from_request_body,
|
||||
)
|
||||
|
||||
if prisma_client is None:
|
||||
return
|
||||
|
||||
# Get tags from request metadata
|
||||
tags = get_tags_from_request_body(request_body=request_body)
|
||||
if not tags:
|
||||
return
|
||||
|
||||
# Batch fetch all tags in one go
|
||||
tag_objects = await get_tag_objects_batch(
|
||||
tag_names=tags,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
# Check budget for each tag
|
||||
for tag_name in tags:
|
||||
tag_object = tag_objects.get(tag_name)
|
||||
if tag_object is None:
|
||||
continue
|
||||
|
||||
# Check if tag has budget limits
|
||||
if (
|
||||
tag_object.litellm_budget_table is not None
|
||||
and tag_object.litellm_budget_table.max_budget is not None
|
||||
and tag_object.spend is not None
|
||||
and tag_object.spend > tag_object.litellm_budget_table.max_budget
|
||||
):
|
||||
raise litellm.BudgetExceededError(
|
||||
current_cost=tag_object.spend,
|
||||
max_budget=tag_object.litellm_budget_table.max_budget,
|
||||
message=f"Budget has been exceeded! Tag={tag_name} Current cost: {tag_object.spend}, Max budget: {tag_object.litellm_budget_table.max_budget}",
|
||||
)
|
||||
|
||||
|
||||
def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
|
||||
"""
|
||||
Check if a model matches an allowed pattern.
|
||||
|
||||
@@ -7,6 +7,9 @@ from fastapi import Request, UploadFile, status
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import ProxyException
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
get_metadata_variable_name_from_kwargs,
|
||||
)
|
||||
from litellm.types.router import Deployment
|
||||
|
||||
|
||||
@@ -245,4 +248,23 @@ async def get_request_body(request: Request) -> Dict[str, Any]:
|
||||
raise ValueError(
|
||||
f"Unsupported content type: {request.headers.get('content-type')}"
|
||||
)
|
||||
return {}
|
||||
return {}
|
||||
|
||||
|
||||
def get_tags_from_request_body(request_body: dict) -> List[str]:
|
||||
"""
|
||||
Extract tags from request body metadata.
|
||||
|
||||
Args:
|
||||
request_body: The request body dictionary
|
||||
|
||||
Returns:
|
||||
List of tag names (strings), empty list if no valid tags found
|
||||
"""
|
||||
metadata_variable_name = get_metadata_variable_name_from_kwargs(request_body)
|
||||
metadata = request_body.get(metadata_variable_name, {})
|
||||
tags_in_metadata: List[str] = metadata.get("tags", [])
|
||||
tags_in_request_body: List[str] = request_body.get("tags", [])
|
||||
combined_tags: List[str] = tags_in_metadata + tags_in_request_body
|
||||
return [tag for tag in combined_tags if isinstance(tag, str)]
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import DualCache, RedisCache
|
||||
from litellm.constants import DB_SPEND_UPDATE_JOB_NAME
|
||||
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
|
||||
from litellm.proxy._types import (
|
||||
DB_CONNECTION_ERROR_TYPES,
|
||||
BaseDailySpendTransaction,
|
||||
@@ -147,6 +148,13 @@ class DBSpendUpdateWriter:
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
)
|
||||
asyncio.create_task(
|
||||
self._update_tag_db(
|
||||
response_cost=response_cost,
|
||||
request_tags=payload.get("request_tags"),
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
)
|
||||
|
||||
if disable_spend_logs is False:
|
||||
await self._insert_spend_log_to_db(
|
||||
@@ -325,6 +333,54 @@ class DBSpendUpdateWriter:
|
||||
)
|
||||
raise e
|
||||
|
||||
async def _update_tag_db(
|
||||
self,
|
||||
response_cost: Optional[float],
|
||||
request_tags: Optional[str],
|
||||
prisma_client: Optional[PrismaClient],
|
||||
):
|
||||
"""
|
||||
Update spend for all tags in the request.
|
||||
|
||||
Args:
|
||||
response_cost: Cost of the request
|
||||
request_tags: JSON string of tags list e.g. '["prod-tag", "test-tag"]'
|
||||
prisma_client: Prisma client instance
|
||||
"""
|
||||
try:
|
||||
if request_tags is None or prisma_client is None:
|
||||
return
|
||||
|
||||
# Parse tags from JSON string
|
||||
tags = []
|
||||
if isinstance(request_tags, str):
|
||||
tags = safe_json_loads(request_tags, default=[])
|
||||
if not tags:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Failed to parse request_tags JSON: {request_tags}"
|
||||
)
|
||||
return
|
||||
elif isinstance(request_tags, list):
|
||||
tags = request_tags
|
||||
else:
|
||||
return
|
||||
|
||||
# Update spend for each tag
|
||||
for tag_name in tags:
|
||||
if tag_name and isinstance(tag_name, str):
|
||||
await self.spend_update_queue.add_update(
|
||||
update=SpendUpdateQueueItem(
|
||||
entity_type=Litellm_EntityType.TAG,
|
||||
entity_id=tag_name,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Update Tag DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
raise e
|
||||
|
||||
async def _insert_spend_log_to_db(
|
||||
self,
|
||||
payload: Union[dict, SpendLogsPayload],
|
||||
@@ -775,6 +831,75 @@ class DBSpendUpdateWriter:
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE TAG TABLE ###
|
||||
tag_list_transactions = db_spend_update_transactions["tag_list_transactions"]
|
||||
await DBSpendUpdateWriter._update_entity_spend_in_db(
|
||||
entity_name="Tag",
|
||||
transactions=tag_list_transactions,
|
||||
table_accessor="litellm_tagtable",
|
||||
where_field="tag_name",
|
||||
n_retry_times=n_retry_times,
|
||||
prisma_client=prisma_client,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _update_entity_spend_in_db(
|
||||
entity_name: str,
|
||||
transactions: Optional[Dict[str, float]],
|
||||
table_accessor: Any,
|
||||
where_field: str,
|
||||
n_retry_times: int,
|
||||
prisma_client: PrismaClient,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
):
|
||||
"""
|
||||
Helper function to update spend for any entity type (team, org, tag, etc).
|
||||
|
||||
Args:
|
||||
entity_name: Name of entity for logging (e.g., "Team", "Org", "Tag")
|
||||
transactions: Dictionary of {entity_id: response_cost}
|
||||
table_accessor: Prisma table accessor (e.g., prisma_client.db.litellm_teamtable)
|
||||
where_field: Field name for where clause (e.g., "team_id", "organization_id", "tag_name")
|
||||
n_retry_times: Number of retries on failure
|
||||
prisma_client: Prisma client instance
|
||||
proxy_logging_obj: Proxy logging object
|
||||
"""
|
||||
from litellm.proxy.utils import _raise_failed_update_spend_exception
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"{entity_name} Spend transactions: {transactions}"
|
||||
)
|
||||
if transactions is not None and len(transactions.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for entity_id, response_cost in transactions.items():
|
||||
verbose_proxy_logger.debug(
|
||||
f"Updating spend for {entity_name} {where_field}={entity_id} by {response_cost}"
|
||||
)
|
||||
getattr(batcher, table_accessor).update_many(
|
||||
where={where_field: entity_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if i >= n_retry_times:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e,
|
||||
start_time=start_time,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
|
||||
@@ -380,6 +380,7 @@ class RedisUpdateBuffer:
|
||||
team_list_transactions={},
|
||||
team_member_list_transactions={},
|
||||
org_list_transactions={},
|
||||
tag_list_transactions={},
|
||||
)
|
||||
|
||||
# Define the transaction fields to process
|
||||
@@ -390,6 +391,7 @@ class RedisUpdateBuffer:
|
||||
"team_list_transactions",
|
||||
"team_member_list_transactions",
|
||||
"org_list_transactions",
|
||||
"tag_list_transactions",
|
||||
]
|
||||
|
||||
# Loop through each transaction and combine the values
|
||||
|
||||
@@ -135,6 +135,7 @@ class SpendUpdateQueue(BaseUpdateQueue):
|
||||
team_list_transactions={},
|
||||
team_member_list_transactions={},
|
||||
org_list_transactions={},
|
||||
tag_list_transactions={},
|
||||
)
|
||||
|
||||
# Map entity types to their corresponding transaction dictionary keys
|
||||
@@ -145,6 +146,7 @@ class SpendUpdateQueue(BaseUpdateQueue):
|
||||
Litellm_EntityType.TEAM: "team_list_transactions",
|
||||
Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions",
|
||||
Litellm_EntityType.ORGANIZATION: "org_list_transactions",
|
||||
Litellm_EntityType.TAG: "tag_list_transactions",
|
||||
}
|
||||
|
||||
for update in updates:
|
||||
@@ -192,6 +194,10 @@ class SpendUpdateQueue(BaseUpdateQueue):
|
||||
transactions_dict = db_spend_update_transactions[
|
||||
"org_list_transactions"
|
||||
]
|
||||
elif dict_key == "tag_list_transactions":
|
||||
transactions_dict = db_spend_update_transactions[
|
||||
"tag_list_transactions"
|
||||
]
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Any, List, Optional, Union, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
@@ -131,6 +131,11 @@ class _ProxyDBLogger(CustomLogger):
|
||||
if sl_object is not None
|
||||
else kwargs.get("response_cost", None)
|
||||
)
|
||||
tags: Optional[List[str]] = (
|
||||
sl_object.get("request_tags", None)
|
||||
if sl_object is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if response_cost is not None:
|
||||
user_api_key = metadata.get("user_api_key", None)
|
||||
@@ -172,6 +177,7 @@ class _ProxyDBLogger(CustomLogger):
|
||||
response_cost=response_cost,
|
||||
team_id=team_id,
|
||||
parent_otel_span=parent_otel_span,
|
||||
tags=tags,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -11,12 +11,12 @@ Endpoints for /organization operations
|
||||
|
||||
#### ORGANIZATION MANAGEMENT ####
|
||||
|
||||
from litellm._uuid import uuid
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.auth_checks import can_user_call_model
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
@@ -29,6 +29,7 @@ from litellm.proxy.management_helpers.object_permission_utils import (
|
||||
)
|
||||
from litellm.proxy.management_helpers.utils import (
|
||||
get_new_internal_user_defaults,
|
||||
handle_budget_for_entity,
|
||||
management_endpoint_wrapper,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
@@ -168,30 +169,17 @@ async def new_organization(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if data.budget_id is None:
|
||||
"""
|
||||
Every organization needs a budget attached.
|
||||
|
||||
If none provided, create one based on provided values
|
||||
"""
|
||||
budget_params = LiteLLM_BudgetTable.model_fields.keys()
|
||||
|
||||
# Only include Budget Params when creating an entry in litellm_budgettable
|
||||
_json_data = data.json(exclude_none=True)
|
||||
_budget_data = {k: v for k, v in _json_data.items() if k in budget_params}
|
||||
budget_row = LiteLLM_BudgetTable(**_budget_data)
|
||||
|
||||
new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True))
|
||||
|
||||
_budget = await prisma_client.db.litellm_budgettable.create(
|
||||
data={
|
||||
**new_budget, # type: ignore
|
||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
}
|
||||
) # type: ignore
|
||||
|
||||
data.budget_id = _budget.budget_id
|
||||
# Handle budget creation/assignment using common helper
|
||||
budget_id = await handle_budget_for_entity(
|
||||
data=data,
|
||||
existing_budget_id=None,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=prisma_client,
|
||||
litellm_proxy_admin_name=litellm_proxy_admin_name,
|
||||
)
|
||||
|
||||
if budget_id is not None:
|
||||
data.budget_id = budget_id
|
||||
|
||||
## Handle Object Permission - MCP, Vector Stores etc.
|
||||
object_permission_id = await _set_object_permission(
|
||||
@@ -322,20 +310,22 @@ async def update_organization(
|
||||
existing_organization_row=existing_organization_row,
|
||||
)
|
||||
|
||||
# Handle budget updates if budget fields are provided
|
||||
budget_fields = {k: v for k, v in data.model_dump().items()
|
||||
if k in LiteLLM_BudgetTable.model_fields.keys() and v is not None}
|
||||
from litellm.proxy.proxy_server import litellm_proxy_admin_name
|
||||
|
||||
# Handle budget updates using common helper
|
||||
budget_id = await handle_budget_for_entity(
|
||||
data=data,
|
||||
existing_budget_id=existing_organization_row.budget_id,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=prisma_client,
|
||||
litellm_proxy_admin_name=litellm_proxy_admin_name,
|
||||
)
|
||||
|
||||
if budget_fields and existing_organization_row.budget_id:
|
||||
await update_budget(
|
||||
budget_obj=BudgetNewRequest(
|
||||
budget_id=existing_organization_row.budget_id,
|
||||
**budget_fields
|
||||
),
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
# Update budget_id if it changed
|
||||
if budget_id != existing_organization_row.budget_id:
|
||||
updated_organization_row["budget_id"] = budget_id
|
||||
|
||||
# Remove budget fields from organization update data
|
||||
# Remove budget fields from organization update data (they're handled via budget table)
|
||||
for field in LiteLLM_BudgetTable.model_fields.keys():
|
||||
updated_organization_row.pop(field, None)
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ All /tag management endpoints
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
@@ -25,6 +24,7 @@ from litellm.proxy.management_endpoints.common_daily_activity import (
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
get_daily_activity,
|
||||
)
|
||||
from litellm.proxy.management_helpers.utils import handle_budget_for_entity
|
||||
from litellm.types.tag_management import (
|
||||
LiteLLM_DailyTagSpendTable,
|
||||
TagConfig,
|
||||
@@ -53,69 +53,6 @@ async def _get_model_names(prisma_client, model_ids: list) -> Dict[str, str]:
|
||||
return {}
|
||||
|
||||
|
||||
async def _get_tags_config(prisma_client) -> Dict[str, TagConfig]:
|
||||
"""Helper function to get tags config from db"""
|
||||
try:
|
||||
tags_config = await prisma_client.db.litellm_config.find_unique(
|
||||
where={"param_name": "tags_config"}
|
||||
)
|
||||
if tags_config is None:
|
||||
return {}
|
||||
# Convert from JSON if needed
|
||||
if isinstance(tags_config.param_value, str):
|
||||
config_dict = json.loads(tags_config.param_value)
|
||||
else:
|
||||
config_dict = tags_config.param_value or {}
|
||||
|
||||
# For each tag, get the model names
|
||||
for tag_name, tag_config in config_dict.items():
|
||||
if isinstance(tag_config, dict) and tag_config.get("models"):
|
||||
model_info = await _get_model_names(prisma_client, tag_config["models"])
|
||||
tag_config["model_info"] = model_info
|
||||
|
||||
return config_dict
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
async def _save_tags_config(prisma_client, tags_config: Dict[str, TagConfig]):
|
||||
"""Helper function to save tags config to db"""
|
||||
try:
|
||||
verbose_proxy_logger.debug(f"Saving tags config: {tags_config}")
|
||||
# Convert TagConfig objects to dictionaries
|
||||
tags_config_dict = {}
|
||||
for name, tag in tags_config.items():
|
||||
if isinstance(tag, TagConfig):
|
||||
tag_dict = tag.model_dump()
|
||||
# Remove model_info before saving as it will be dynamically generated
|
||||
if "model_info" in tag_dict:
|
||||
del tag_dict["model_info"]
|
||||
tags_config_dict[name] = tag_dict
|
||||
else:
|
||||
# If it's already a dict, remove model_info
|
||||
tag_copy = tag.copy()
|
||||
if "model_info" in tag_copy:
|
||||
del tag_copy["model_info"]
|
||||
tags_config_dict[name] = tag_copy
|
||||
|
||||
json_tags_config = json.dumps(tags_config_dict, default=str)
|
||||
verbose_proxy_logger.debug(f"JSON tags config: {json_tags_config}")
|
||||
await prisma_client.db.litellm_config.upsert(
|
||||
where={"param_name": "tags_config"},
|
||||
data={
|
||||
"create": {
|
||||
"param_name": "tags_config",
|
||||
"param_value": json_tags_config,
|
||||
},
|
||||
"update": {"param_value": json_tags_config},
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error saving tags config: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
async def get_deployments_by_model(
|
||||
model: str, llm_router: "Router"
|
||||
) -> List["Deployment"]:
|
||||
@@ -159,9 +96,23 @@ async def new_tag(
|
||||
- name: str - The name of the tag
|
||||
- description: Optional[str] - Description of what this tag represents
|
||||
- models: List[str] - List of either 'model_id' or 'model_name' allowed for this tag
|
||||
- budget_id: Optional[str] - The id for a budget (tpm/rpm/max budget) for the tag
|
||||
|
||||
### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ###
|
||||
- max_budget: Optional[float] - Max budget for tag
|
||||
- tpm_limit: Optional[int] - Max tpm limit for tag
|
||||
- rpm_limit: Optional[int] - Max rpm limit for tag
|
||||
- max_parallel_requests: Optional[int] - Max parallel requests for tag
|
||||
- soft_budget: Optional[float] - Get a slack alert when this soft budget is reached
|
||||
- model_max_budget: Optional[dict] - Max budget for a specific model
|
||||
- budget_duration: Optional[str] - Frequency of resetting tag budget
|
||||
"""
|
||||
from litellm.proxy._types import CommonProxyErrors
|
||||
from litellm.proxy.proxy_server import llm_router, prisma_client
|
||||
from litellm.proxy.proxy_server import (
|
||||
litellm_proxy_admin_name,
|
||||
llm_router,
|
||||
prisma_client,
|
||||
)
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
@@ -172,29 +123,38 @@ async def new_tag(
|
||||
status_code=500, detail=CommonProxyErrors.no_llm_router.value
|
||||
)
|
||||
try:
|
||||
# Get existing tags config
|
||||
tags_config = await _get_tags_config(prisma_client)
|
||||
|
||||
# Check if tag already exists
|
||||
if tag.name in tags_config:
|
||||
existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
|
||||
where={"tag_name": tag.name}
|
||||
)
|
||||
if existing_tag is not None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Tag {tag.name} already exists"
|
||||
)
|
||||
|
||||
# Add new tag
|
||||
tags_config[tag.name] = TagConfig(
|
||||
name=tag.name,
|
||||
description=tag.description,
|
||||
models=tag.models,
|
||||
created_at=str(datetime.datetime.now()),
|
||||
updated_at=str(datetime.datetime.now()),
|
||||
created_by=user_api_key_dict.user_id,
|
||||
# Handle budget creation/assignment using common helper
|
||||
budget_id = await handle_budget_for_entity(
|
||||
data=tag,
|
||||
existing_budget_id=None,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=prisma_client,
|
||||
litellm_proxy_admin_name=litellm_proxy_admin_name,
|
||||
)
|
||||
|
||||
# Save updated config
|
||||
await _save_tags_config(
|
||||
prisma_client=prisma_client,
|
||||
tags_config=tags_config,
|
||||
# Get model names for model_info
|
||||
model_info = await _get_model_names(prisma_client, tag.models or [])
|
||||
|
||||
# Create new tag in database
|
||||
new_tag_record = await prisma_client.db.litellm_tagtable.create(
|
||||
data={
|
||||
"tag_name": tag.name,
|
||||
"description": tag.description,
|
||||
"models": tag.models or [],
|
||||
"model_info": json.dumps(model_info),
|
||||
"spend": 0.0,
|
||||
"budget_id": budget_id,
|
||||
"created_by": user_api_key_dict.user_id,
|
||||
}
|
||||
)
|
||||
|
||||
# Update models with new tag
|
||||
@@ -213,13 +173,20 @@ async def new_tag(
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Get model names for response
|
||||
model_info = await _get_model_names(prisma_client, tag.models or [])
|
||||
tags_config[tag.name].model_info = model_info
|
||||
# Build response
|
||||
tag_config = TagConfig(
|
||||
name=new_tag_record.tag_name,
|
||||
description=new_tag_record.description,
|
||||
models=new_tag_record.models,
|
||||
model_info=model_info,
|
||||
created_at=new_tag_record.created_at.isoformat(),
|
||||
updated_at=new_tag_record.updated_at.isoformat(),
|
||||
created_by=new_tag_record.created_by,
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Tag {tag.name} created successfully",
|
||||
"tag": tags_config[tag.name],
|
||||
"tag": tag_config,
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error creating tag: {str(e)}")
|
||||
@@ -264,6 +231,16 @@ async def update_tag(
|
||||
- name: str - The name of the tag to update
|
||||
- description: Optional[str] - Updated description
|
||||
- models: List[str] - Updated list of allowed LLM models
|
||||
- budget_id: Optional[str] - The id for a budget to associate with the tag
|
||||
|
||||
### BUDGET UPDATE PARAMS ###
|
||||
- max_budget: Optional[float] - Max budget for tag
|
||||
- tpm_limit: Optional[int] - Max tpm limit for tag
|
||||
- rpm_limit: Optional[int] - Max rpm limit for tag
|
||||
- max_parallel_requests: Optional[int] - Max parallel requests for tag
|
||||
- soft_budget: Optional[float] - Get a slack alert when this soft budget is reached
|
||||
- model_max_budget: Optional[dict] - Max budget for a specific model
|
||||
- budget_duration: Optional[str] - Frequency of resetting tag budget
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
@@ -271,35 +248,58 @@ async def update_tag(
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
# Get existing tags config
|
||||
tags_config = await _get_tags_config(prisma_client)
|
||||
|
||||
# Check if tag exists
|
||||
if tag.name not in tags_config:
|
||||
existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
|
||||
where={"tag_name": tag.name}
|
||||
)
|
||||
if existing_tag is None:
|
||||
raise HTTPException(status_code=404, detail=f"Tag {tag.name} not found")
|
||||
|
||||
# Update tag
|
||||
tag_config_dict = dict(tags_config[tag.name])
|
||||
tag_config_dict.update(
|
||||
{
|
||||
"description": tag.description,
|
||||
"models": tag.models,
|
||||
"updated_at": str(datetime.datetime.now()),
|
||||
"updated_by": user_api_key_dict.user_id,
|
||||
}
|
||||
from litellm.proxy.proxy_server import litellm_proxy_admin_name
|
||||
|
||||
# Handle budget updates using common helper
|
||||
budget_id = await handle_budget_for_entity(
|
||||
data=tag,
|
||||
existing_budget_id=existing_tag.budget_id,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=prisma_client,
|
||||
litellm_proxy_admin_name=litellm_proxy_admin_name,
|
||||
)
|
||||
tags_config[tag.name] = TagConfig(**tag_config_dict)
|
||||
|
||||
# Save updated config
|
||||
await _save_tags_config(prisma_client, tags_config)
|
||||
|
||||
# Get model names for response
|
||||
# Get model names for model_info
|
||||
model_info = await _get_model_names(prisma_client, tag.models or [])
|
||||
tags_config[tag.name].model_info = model_info
|
||||
|
||||
# Prepare update data
|
||||
update_data = {
|
||||
"description": tag.description,
|
||||
"models": tag.models or [],
|
||||
"model_info": json.dumps(model_info),
|
||||
}
|
||||
|
||||
# Add budget_id if it changed
|
||||
if budget_id != existing_tag.budget_id:
|
||||
update_data["budget_id"] = budget_id
|
||||
|
||||
# Update tag in database
|
||||
updated_tag_record = await prisma_client.db.litellm_tagtable.update(
|
||||
where={"tag_name": tag.name},
|
||||
data=update_data,
|
||||
)
|
||||
|
||||
# Build response
|
||||
tag_config = TagConfig(
|
||||
name=updated_tag_record.tag_name,
|
||||
description=updated_tag_record.description,
|
||||
models=updated_tag_record.models,
|
||||
model_info=model_info,
|
||||
created_at=updated_tag_record.created_at.isoformat(),
|
||||
updated_at=updated_tag_record.updated_at.isoformat(),
|
||||
created_by=updated_tag_record.created_by,
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Tag {tag.name} updated successfully",
|
||||
"tag": tags_config[tag.name],
|
||||
"tag": tag_config,
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error updating tag: {str(e)}")
|
||||
@@ -327,18 +327,47 @@ async def info_tag(
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
tags_config = await _get_tags_config(prisma_client)
|
||||
|
||||
# Filter tags based on requested names
|
||||
requested_tags = {name: tags_config.get(name) for name in data.names}
|
||||
# Query tags from database with budget info
|
||||
tag_records = await prisma_client.db.litellm_tagtable.find_many(
|
||||
where={"tag_name": {"in": data.names}},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
|
||||
# Check if any requested tags don't exist
|
||||
missing_tags = [name for name in data.names if name not in tags_config]
|
||||
found_tag_names = {tag.tag_name for tag in tag_records}
|
||||
missing_tags = [name for name in data.names if name not in found_tag_names]
|
||||
if missing_tags:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Tags not found: {missing_tags}"
|
||||
)
|
||||
|
||||
# Build response
|
||||
requested_tags = {}
|
||||
for tag_record in tag_records:
|
||||
# Parse model_info from JSON
|
||||
model_info = {}
|
||||
if tag_record.model_info:
|
||||
if isinstance(tag_record.model_info, str):
|
||||
model_info = json.loads(tag_record.model_info)
|
||||
else:
|
||||
model_info = tag_record.model_info
|
||||
|
||||
tag_dict = {
|
||||
"name": tag_record.tag_name,
|
||||
"description": tag_record.description,
|
||||
"models": tag_record.models,
|
||||
"model_info": model_info,
|
||||
"created_at": tag_record.created_at.isoformat(),
|
||||
"updated_at": tag_record.updated_at.isoformat(),
|
||||
"created_by": tag_record.created_by,
|
||||
}
|
||||
|
||||
# Add budget info if available
|
||||
if hasattr(tag_record, "litellm_budget_table") and tag_record.litellm_budget_table:
|
||||
tag_dict["litellm_budget_table"] = tag_record.litellm_budget_table
|
||||
|
||||
requested_tags[tag_record.tag_name] = tag_dict
|
||||
|
||||
return requested_tags
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -348,13 +377,12 @@ async def info_tag(
|
||||
"/tag/list",
|
||||
tags=["tag management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=List[TagConfig],
|
||||
)
|
||||
async def list_tags(
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
List all available tags.
|
||||
List all available tags with their budget information.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
@@ -363,8 +391,37 @@ async def list_tags(
|
||||
|
||||
try:
|
||||
## QUERY STORED TAGS ##
|
||||
tags_config = await _get_tags_config(prisma_client)
|
||||
list_of_tags = list(tags_config.values())
|
||||
tag_records = await prisma_client.db.litellm_tagtable.find_many(
|
||||
include={"litellm_budget_table": True}
|
||||
)
|
||||
|
||||
stored_tag_names = set()
|
||||
list_of_tags = []
|
||||
for tag_record in tag_records:
|
||||
stored_tag_names.add(tag_record.tag_name)
|
||||
# Parse model_info from JSON
|
||||
model_info = {}
|
||||
if tag_record.model_info:
|
||||
if isinstance(tag_record.model_info, str):
|
||||
model_info = json.loads(tag_record.model_info)
|
||||
else:
|
||||
model_info = tag_record.model_info
|
||||
|
||||
tag_dict = {
|
||||
"name": tag_record.tag_name,
|
||||
"description": tag_record.description,
|
||||
"models": tag_record.models,
|
||||
"model_info": model_info,
|
||||
"created_at": tag_record.created_at.isoformat(),
|
||||
"updated_at": tag_record.updated_at.isoformat(),
|
||||
"created_by": tag_record.created_by,
|
||||
}
|
||||
|
||||
# Add budget info if available
|
||||
if hasattr(tag_record, "litellm_budget_table") and tag_record.litellm_budget_table:
|
||||
tag_dict["litellm_budget_table"] = tag_record.litellm_budget_table
|
||||
|
||||
list_of_tags.append(tag_dict)
|
||||
|
||||
## QUERY DYNAMIC TAGS ##
|
||||
dynamic_tags = await prisma_client.db.litellm_dailytagspend.find_many(
|
||||
@@ -377,15 +434,15 @@ async def list_tags(
|
||||
]
|
||||
|
||||
dynamic_tag_config = [
|
||||
TagConfig(
|
||||
name=tag.tag,
|
||||
description="This is just a spend tag that was passed dynamically in a request. It does not control any LLM models.",
|
||||
models=None,
|
||||
created_at=tag.created_at.isoformat(),
|
||||
updated_at=tag.updated_at.isoformat(),
|
||||
)
|
||||
{
|
||||
"name": tag.tag,
|
||||
"description": "This is just a spend tag that was passed dynamically in a request. It does not control any LLM models.",
|
||||
"models": None,
|
||||
"created_at": tag.created_at.isoformat(),
|
||||
"updated_at": tag.updated_at.isoformat(),
|
||||
}
|
||||
for tag in dynamic_tags_list
|
||||
if tag.tag not in tags_config
|
||||
if tag.tag not in stored_tag_names
|
||||
]
|
||||
|
||||
return list_of_tags + dynamic_tag_config
|
||||
@@ -414,18 +471,15 @@ async def delete_tag(
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
# Get existing tags config
|
||||
tags_config = await _get_tags_config(prisma_client)
|
||||
|
||||
# Check if tag exists
|
||||
if data.name not in tags_config:
|
||||
existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
|
||||
where={"tag_name": data.name}
|
||||
)
|
||||
if existing_tag is None:
|
||||
raise HTTPException(status_code=404, detail=f"Tag {data.name} not found")
|
||||
|
||||
# Delete tag
|
||||
del tags_config[data.name]
|
||||
|
||||
# Save updated config
|
||||
await _save_tags_config(prisma_client, tags_config)
|
||||
# Delete tag from database
|
||||
await prisma_client.db.litellm_tagtable.delete(where={"tag_name": data.name})
|
||||
|
||||
return {"message": f"Tag {data.name} deleted successfully"}
|
||||
except Exception as e:
|
||||
|
||||
@@ -10,10 +10,12 @@ import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.proxy._types import ( # key request types; user request types; team request types; customer request types
|
||||
BudgetNewRequest,
|
||||
DeleteCustomerRequest,
|
||||
DeleteTeamRequest,
|
||||
DeleteUserRequest,
|
||||
KeyRequest,
|
||||
LiteLLM_BudgetTable,
|
||||
LiteLLM_TeamMembership,
|
||||
LiteLLM_UserTable,
|
||||
ManagementEndpointLoggingPayload,
|
||||
@@ -53,6 +55,89 @@ def get_new_internal_user_defaults(
|
||||
return non_null_dict
|
||||
|
||||
|
||||
async def handle_budget_for_entity(
|
||||
data,
|
||||
existing_budget_id: Optional[str],
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
prisma_client: PrismaClient,
|
||||
litellm_proxy_admin_name: str,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Common helper to handle budget creation/updates for entities (organizations, tags, etc).
|
||||
|
||||
This function:
|
||||
1. Creates a new budget if budget_id is None but budget fields are provided
|
||||
2. Updates an existing budget if budget fields are provided and budget_id exists
|
||||
3. Returns the budget_id to use (existing or newly created)
|
||||
|
||||
Args:
|
||||
data: The request object (e.g., TagNewRequest, NewOrganizationRequest, etc.) containing budget fields
|
||||
existing_budget_id: The existing budget_id if updating an entity, None if creating new
|
||||
user_api_key_dict: User authentication info
|
||||
prisma_client: Database client
|
||||
litellm_proxy_admin_name: Admin name for audit trail
|
||||
|
||||
Returns:
|
||||
Optional[str]: The budget_id to use, or None if no budget was created/updated
|
||||
"""
|
||||
from litellm.proxy.management_endpoints.budget_management_endpoints import (
|
||||
update_budget,
|
||||
)
|
||||
|
||||
# Get all budget field names
|
||||
budget_params = LiteLLM_BudgetTable.model_fields.keys()
|
||||
|
||||
# Extract budget fields from data
|
||||
_json_data = data.model_dump(exclude_none=True) if hasattr(data, "model_dump") else data
|
||||
_budget_data = {k: v for k, v in _json_data.items() if k in budget_params}
|
||||
|
||||
# Check if budget_id is explicitly provided in the data
|
||||
data_budget_id = getattr(data, "budget_id", None)
|
||||
|
||||
# Case 1: Creating new entity - no existing budget_id
|
||||
if existing_budget_id is None:
|
||||
if data_budget_id is not None:
|
||||
# Use the provided budget_id
|
||||
return data_budget_id
|
||||
elif _budget_data:
|
||||
# Create a new budget with the provided fields
|
||||
budget_row = LiteLLM_BudgetTable(**_budget_data)
|
||||
new_budget_data = prisma_client.jsonify_object(
|
||||
budget_row.model_dump(exclude_none=True)
|
||||
)
|
||||
|
||||
_budget = await prisma_client.db.litellm_budgettable.create(
|
||||
data={
|
||||
**new_budget_data, # type: ignore
|
||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
}
|
||||
) # type: ignore
|
||||
|
||||
return _budget.budget_id
|
||||
else:
|
||||
# No budget fields provided, no budget to create
|
||||
return None
|
||||
|
||||
# Case 2: Updating existing entity - has existing budget_id
|
||||
else:
|
||||
# If budget fields are provided, update the existing budget
|
||||
if _budget_data:
|
||||
await update_budget(
|
||||
budget_obj=BudgetNewRequest(
|
||||
budget_id=existing_budget_id, **_budget_data
|
||||
),
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
# If a different budget_id is explicitly provided, use that instead
|
||||
if data_budget_id is not None and data_budget_id != existing_budget_id:
|
||||
return data_budget_id
|
||||
|
||||
# Otherwise, keep using the existing budget_id
|
||||
return existing_budget_id
|
||||
|
||||
|
||||
async def add_new_member(
|
||||
new_member: Member,
|
||||
max_budget_in_team: Optional[float],
|
||||
|
||||
@@ -33,9 +33,9 @@ from litellm.constants import (
|
||||
AIOHTTP_TTL_DNS_CACHE,
|
||||
BASE_MCP_ROUTE,
|
||||
DEFAULT_MAX_RECURSE_DEPTH,
|
||||
DEFAULT_SLACK_ALERTING_THRESHOLD,
|
||||
DEFAULT_SHARED_HEALTH_CHECK_TTL,
|
||||
DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL,
|
||||
DEFAULT_SHARED_HEALTH_CHECK_TTL,
|
||||
DEFAULT_SLACK_ALERTING_THRESHOLD,
|
||||
LITELLM_EMBEDDING_PROVIDERS_SUPPORTING_INPUT_ARRAY_OF_TOKENS,
|
||||
LITELLM_SETTINGS_SAFE_DB_OVERRIDES,
|
||||
)
|
||||
@@ -1157,6 +1157,7 @@ async def update_cache( # noqa: PLR0915
|
||||
team_id: Optional[str],
|
||||
response_cost: Optional[float],
|
||||
parent_otel_span: Optional[Span], # type: ignore
|
||||
tags: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Use this to update the cache with new user spend.
|
||||
@@ -1377,6 +1378,50 @@ async def update_cache( # noqa: PLR0915
|
||||
f"An error occurred updating end user cache: {str(e)}"
|
||||
)
|
||||
|
||||
### UPDATE TAG SPEND ###
|
||||
async def _update_tag_cache():
|
||||
"""
|
||||
Update the tag cache with the new spend.
|
||||
"""
|
||||
if tags is None or response_cost is None:
|
||||
return
|
||||
|
||||
try:
|
||||
for tag_name in tags:
|
||||
if not tag_name or not isinstance(tag_name, str):
|
||||
continue
|
||||
|
||||
cache_key = f"tag:{tag_name}"
|
||||
# Fetch the existing tag object from cache
|
||||
existing_tag_obj = await user_api_key_cache.async_get_cache(key=cache_key)
|
||||
if existing_tag_obj is None:
|
||||
# do nothing if tag not in api key cache
|
||||
continue
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"_update_tag_cache: existing spend for tag={tag_name}: {existing_tag_obj}; response_cost: {response_cost}"
|
||||
)
|
||||
|
||||
if isinstance(existing_tag_obj, dict):
|
||||
existing_spend = existing_tag_obj.get("spend", 0) or 0
|
||||
else:
|
||||
existing_spend = getattr(existing_tag_obj, "spend", 0) or 0
|
||||
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
new_spend = existing_spend + response_cost
|
||||
|
||||
# Update the spend column for the given tag
|
||||
if isinstance(existing_tag_obj, dict):
|
||||
existing_tag_obj["spend"] = new_spend
|
||||
values_to_update_in_cache.append((cache_key, existing_tag_obj))
|
||||
else:
|
||||
existing_tag_obj.spend = new_spend
|
||||
values_to_update_in_cache.append((cache_key, existing_tag_obj))
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"An error occurred updating tag cache: {str(e)}"
|
||||
)
|
||||
|
||||
if token is not None and response_cost is not None:
|
||||
await _update_key_cache(token=token, response_cost=response_cost)
|
||||
|
||||
@@ -1389,6 +1434,9 @@ async def update_cache( # noqa: PLR0915
|
||||
if team_id is not None:
|
||||
await _update_team_cache()
|
||||
|
||||
if tags is not None:
|
||||
await _update_tag_cache()
|
||||
|
||||
asyncio.create_task(
|
||||
user_api_key_cache.async_set_cache_pipeline(
|
||||
cache_list=values_to_update_in_cache,
|
||||
@@ -1431,7 +1479,9 @@ async def _run_background_health_check():
|
||||
# Initialize shared health check manager if Redis is available and feature is enabled
|
||||
shared_health_manager = None
|
||||
if use_shared_health_check and redis_usage_cache is not None:
|
||||
from litellm.proxy.health_check_utils.shared_health_check_manager import SharedHealthCheckManager
|
||||
from litellm.proxy.health_check_utils.shared_health_check_manager import (
|
||||
SharedHealthCheckManager,
|
||||
)
|
||||
shared_health_manager = SharedHealthCheckManager(
|
||||
redis_cache=redis_usage_cache,
|
||||
health_check_ttl=DEFAULT_SHARED_HEALTH_CHECK_TTL,
|
||||
@@ -1451,10 +1501,13 @@ async def _run_background_health_check():
|
||||
]
|
||||
|
||||
# Use shared health check if available, otherwise fall back to direct health check
|
||||
# Convert health_check_details to bool for perform_shared_health_check (defaults to True if None)
|
||||
details_bool = health_check_details if health_check_details is not None else True
|
||||
|
||||
if shared_health_manager is not None:
|
||||
try:
|
||||
healthy_endpoints, unhealthy_endpoints = await shared_health_manager.perform_shared_health_check(
|
||||
model_list=_llm_model_list, details=health_check_details
|
||||
model_list=_llm_model_list, details=details_bool
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
|
||||
@@ -25,6 +25,7 @@ model LiteLLM_BudgetTable {
|
||||
organization LiteLLM_OrganizationTable[] // multiple orgs can have the same budget
|
||||
keys LiteLLM_VerificationToken[] // multiple keys can have the same budget
|
||||
end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget
|
||||
tags LiteLLM_TagTable[] // multiple tags can have the same budget
|
||||
team_membership LiteLLM_TeamMembership[] // budgets of Users within a Team
|
||||
organization_membership LiteLLM_OrganizationMembership[] // budgets of Users within a Organization
|
||||
}
|
||||
@@ -245,6 +246,20 @@ model LiteLLM_EndUserTable {
|
||||
blocked Boolean @default(false)
|
||||
}
|
||||
|
||||
// Track tags with budgets and spend
|
||||
model LiteLLM_TagTable {
|
||||
tag_name String @id
|
||||
description String?
|
||||
models String[]
|
||||
model_info Json? // maps model_id to model_name
|
||||
spend Float @default(0.0)
|
||||
budget_id String?
|
||||
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
|
||||
created_at DateTime @default(now()) @map("created_at")
|
||||
created_by String?
|
||||
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||
}
|
||||
|
||||
// store proxy config.yaml
|
||||
model LiteLLM_Config {
|
||||
param_name String @id
|
||||
|
||||
@@ -18,11 +18,27 @@ class TagConfig(TagBase):
|
||||
|
||||
|
||||
class TagNewRequest(TagBase):
|
||||
pass
|
||||
budget_id: Optional[str] = None
|
||||
# Budget fields - if budget_id is None, create a new budget with these params
|
||||
max_budget: Optional[float] = None
|
||||
soft_budget: Optional[float] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
model_max_budget: Optional[Dict] = None
|
||||
budget_duration: Optional[str] = None
|
||||
|
||||
|
||||
class TagUpdateRequest(TagBase):
|
||||
pass
|
||||
budget_id: Optional[str] = None
|
||||
# Budget fields - if provided, will update the budget
|
||||
max_budget: Optional[float] = None
|
||||
soft_budget: Optional[float] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
model_max_budget: Optional[Dict] = None
|
||||
budget_duration: Optional[str] = None
|
||||
|
||||
|
||||
class TagDeleteRequest(BaseModel):
|
||||
|
||||
@@ -25,6 +25,7 @@ model LiteLLM_BudgetTable {
|
||||
organization LiteLLM_OrganizationTable[] // multiple orgs can have the same budget
|
||||
keys LiteLLM_VerificationToken[] // multiple keys can have the same budget
|
||||
end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget
|
||||
tags LiteLLM_TagTable[] // multiple tags can have the same budget
|
||||
team_membership LiteLLM_TeamMembership[] // budgets of Users within a Team
|
||||
organization_membership LiteLLM_OrganizationMembership[] // budgets of Users within a Organization
|
||||
}
|
||||
@@ -245,6 +246,20 @@ model LiteLLM_EndUserTable {
|
||||
blocked Boolean @default(false)
|
||||
}
|
||||
|
||||
// Track tags with budgets and spend
|
||||
model LiteLLM_TagTable {
|
||||
tag_name String @id
|
||||
description String?
|
||||
models String[]
|
||||
model_info Json? // maps model_id to model_name
|
||||
spend Float @default(0.0)
|
||||
budget_id String?
|
||||
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
|
||||
created_at DateTime @default(now()) @map("created_at")
|
||||
created_by String?
|
||||
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||
}
|
||||
|
||||
// store proxy config.yaml
|
||||
model LiteLLM_Config {
|
||||
param_name String @id
|
||||
|
||||
@@ -26,9 +26,9 @@ from litellm.proxy._types import (
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
ExperimentalUIJWTToken,
|
||||
_can_object_call_vector_stores,
|
||||
_get_team_db_check,
|
||||
get_user_object,
|
||||
vector_store_access_check,
|
||||
_get_team_db_check,
|
||||
)
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper
|
||||
from litellm.utils import get_utc_datetime
|
||||
@@ -567,3 +567,141 @@ def test_can_object_call_model_no_access_to_alias_or_underlying():
|
||||
assert exc_info.value.type == ProxyErrorTypes.key_model_access_denied
|
||||
assert "key not allowed to access model" in str(exc_info.value.message)
|
||||
assert "my-fake-gpt" in str(exc_info.value.message)
|
||||
|
||||
|
||||
# Tag Budget Enforcement Tests
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tag_objects_batch():
|
||||
"""
|
||||
Test batch fetching of tags validates:
|
||||
- Cached tags are fetched from cache (no DB call for them)
|
||||
- Uncached tags are fetched in ONE batch DB query
|
||||
- After fetching, uncached tags are cached
|
||||
"""
|
||||
from litellm.proxy._types import LiteLLM_TagTable
|
||||
from litellm.proxy.auth.auth_checks import get_tag_objects_batch
|
||||
|
||||
mock_prisma = MagicMock()
|
||||
mock_cache = MagicMock()
|
||||
mock_proxy_logging = MagicMock()
|
||||
|
||||
# Simulate 5 tags: 2 cached, 3 uncached
|
||||
tag_names = ["cached-1", "uncached-1", "cached-2", "uncached-2", "uncached-3"]
|
||||
|
||||
# Mock cached tags
|
||||
cached_tag_1 = {
|
||||
"tag_name": "cached-1",
|
||||
"spend": 10.0,
|
||||
"models": [],
|
||||
"litellm_budget_table": None,
|
||||
}
|
||||
cached_tag_2 = {
|
||||
"tag_name": "cached-2",
|
||||
"spend": 20.0,
|
||||
"models": [],
|
||||
"litellm_budget_table": None,
|
||||
}
|
||||
|
||||
# Mock DB response for uncached tags
|
||||
uncached_tag_1 = MagicMock()
|
||||
uncached_tag_1.tag_name = "uncached-1"
|
||||
uncached_tag_1.spend = 30.0
|
||||
uncached_tag_1.models = []
|
||||
uncached_tag_1.litellm_budget_table = None
|
||||
uncached_tag_1.dict = MagicMock(
|
||||
return_value={
|
||||
"tag_name": "uncached-1",
|
||||
"spend": 30.0,
|
||||
"models": [],
|
||||
"litellm_budget_table": None,
|
||||
}
|
||||
)
|
||||
|
||||
uncached_tag_2 = MagicMock()
|
||||
uncached_tag_2.tag_name = "uncached-2"
|
||||
uncached_tag_2.spend = 40.0
|
||||
uncached_tag_2.models = []
|
||||
uncached_tag_2.litellm_budget_table = None
|
||||
uncached_tag_2.dict = MagicMock(
|
||||
return_value={
|
||||
"tag_name": "uncached-2",
|
||||
"spend": 40.0,
|
||||
"models": [],
|
||||
"litellm_budget_table": None,
|
||||
}
|
||||
)
|
||||
|
||||
uncached_tag_3 = MagicMock()
|
||||
uncached_tag_3.tag_name = "uncached-3"
|
||||
uncached_tag_3.spend = 50.0
|
||||
uncached_tag_3.models = []
|
||||
uncached_tag_3.litellm_budget_table = None
|
||||
uncached_tag_3.dict = MagicMock(
|
||||
return_value={
|
||||
"tag_name": "uncached-3",
|
||||
"spend": 50.0,
|
||||
"models": [],
|
||||
"litellm_budget_table": None,
|
||||
}
|
||||
)
|
||||
|
||||
# Mock cache behavior - return cached tags, None for uncached
|
||||
async def mock_get_cache(key):
|
||||
if key == "tag:cached-1":
|
||||
return cached_tag_1
|
||||
elif key == "tag:cached-2":
|
||||
return cached_tag_2
|
||||
else:
|
||||
return None
|
||||
|
||||
mock_cache.async_get_cache = AsyncMock(side_effect=mock_get_cache)
|
||||
mock_cache.async_set_cache = AsyncMock()
|
||||
|
||||
# Mock DB to return all uncached tags in ONE query
|
||||
mock_prisma.db.litellm_tagtable.find_many = AsyncMock(
|
||||
return_value=[uncached_tag_1, uncached_tag_2, uncached_tag_3]
|
||||
)
|
||||
|
||||
# Call batch fetch
|
||||
tag_objects = await get_tag_objects_batch(
|
||||
tag_names=tag_names,
|
||||
prisma_client=mock_prisma,
|
||||
user_api_key_cache=mock_cache,
|
||||
proxy_logging_obj=mock_proxy_logging,
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert len(tag_objects) == 5
|
||||
assert "cached-1" in tag_objects
|
||||
assert "cached-2" in tag_objects
|
||||
assert "uncached-1" in tag_objects
|
||||
assert "uncached-2" in tag_objects
|
||||
assert "uncached-3" in tag_objects
|
||||
|
||||
# Verify cached tags have correct values
|
||||
assert tag_objects["cached-1"].spend == 10.0
|
||||
assert tag_objects["cached-2"].spend == 20.0
|
||||
|
||||
# Verify uncached tags have correct values
|
||||
assert tag_objects["uncached-1"].spend == 30.0
|
||||
assert tag_objects["uncached-2"].spend == 40.0
|
||||
assert tag_objects["uncached-3"].spend == 50.0
|
||||
|
||||
# Verify DB was called ONCE with all 3 uncached tags
|
||||
mock_prisma.db.litellm_tagtable.find_many.assert_called_once()
|
||||
call_args = mock_prisma.db.litellm_tagtable.find_many.call_args
|
||||
assert call_args.kwargs["where"]["tag_name"]["in"] == [
|
||||
"uncached-1",
|
||||
"uncached-2",
|
||||
"uncached-3",
|
||||
]
|
||||
|
||||
# Verify uncached tags were cached after fetching
|
||||
assert mock_cache.async_set_cache.call_count == 3
|
||||
cache_calls = mock_cache.async_set_cache.call_args_list
|
||||
cached_keys = [call.kwargs["key"] for call in cache_calls]
|
||||
assert "tag:uncached-1" in cached_keys
|
||||
assert "tag:uncached-2" in cached_keys
|
||||
assert "tag:uncached-3" in cached_keys
|
||||
|
||||
@@ -14,13 +14,17 @@ sys.path.insert(
|
||||
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import ProxyException
|
||||
from litellm.proxy.common_utils.http_parsing_utils import (
|
||||
_read_request_body,
|
||||
_safe_get_request_headers,
|
||||
_safe_get_request_parsed_body,
|
||||
_safe_get_request_query_params,
|
||||
_safe_set_request_parsed_body,
|
||||
get_form_data,
|
||||
get_request_body,
|
||||
get_tags_from_request_body,
|
||||
)
|
||||
from litellm.proxy._types import ProxyException
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -285,3 +289,96 @@ async def test_get_form_data():
|
||||
# Note: In a real MultiDict, both values would be present
|
||||
# But in our mock dictionary the second value overwrites the first
|
||||
assert "segment" in result["timestamp_granularities"]
|
||||
|
||||
|
||||
def test_get_tags_from_request_body_with_metadata_tags():
|
||||
"""
|
||||
Test that tags are correctly extracted from request body metadata.
|
||||
"""
|
||||
request_body = {
|
||||
"model": "gpt-4",
|
||||
"metadata": {
|
||||
"tags": ["tag1", "tag2", "tag3"]
|
||||
}
|
||||
}
|
||||
|
||||
result = get_tags_from_request_body(request_body=request_body)
|
||||
|
||||
assert result == ["tag1", "tag2", "tag3"]
|
||||
|
||||
|
||||
def test_get_tags_from_request_body_with_litellm_metadata_tags():
|
||||
"""
|
||||
Test that tags are correctly extracted from request body when using litellm_metadata.
|
||||
"""
|
||||
request_body = {
|
||||
"model": "gpt-4",
|
||||
"litellm_metadata": {
|
||||
"tags": ["tag1", "tag2", "tag3"]
|
||||
}
|
||||
}
|
||||
|
||||
result = get_tags_from_request_body(request_body=request_body)
|
||||
|
||||
assert result == ["tag1", "tag2", "tag3"]
|
||||
|
||||
|
||||
def test_get_tags_from_request_body_with_root_tags():
|
||||
"""
|
||||
Test that tags are correctly extracted from root level of request body.
|
||||
"""
|
||||
request_body = {
|
||||
"model": "gpt-4",
|
||||
"tags": ["tag1", "tag2"]
|
||||
}
|
||||
|
||||
result = get_tags_from_request_body(request_body=request_body)
|
||||
|
||||
assert result == ["tag1", "tag2"]
|
||||
|
||||
|
||||
def test_get_tags_from_request_body_with_combined_tags():
|
||||
"""
|
||||
Test that tags from both metadata and root level are combined.
|
||||
"""
|
||||
request_body = {
|
||||
"model": "gpt-4",
|
||||
"metadata": {
|
||||
"tags": ["tag1", "tag2"]
|
||||
},
|
||||
"tags": ["tag3", "tag4"]
|
||||
}
|
||||
|
||||
result = get_tags_from_request_body(request_body=request_body)
|
||||
|
||||
assert result == ["tag1", "tag2", "tag3", "tag4"]
|
||||
|
||||
|
||||
def test_get_tags_from_request_body_filters_non_strings():
|
||||
"""
|
||||
Test that non-string values in tags list are filtered out.
|
||||
"""
|
||||
request_body = {
|
||||
"model": "gpt-4",
|
||||
"metadata": {
|
||||
"tags": ["tag1", 123, "tag2", None, "tag3", {"nested": "dict"}]
|
||||
}
|
||||
}
|
||||
|
||||
result = get_tags_from_request_body(request_body=request_body)
|
||||
|
||||
assert result == ["tag1", "tag2", "tag3"]
|
||||
|
||||
|
||||
def test_get_tags_from_request_body_no_tags():
|
||||
"""
|
||||
Test that empty list is returned when no tags are present.
|
||||
"""
|
||||
request_body = {
|
||||
"model": "gpt-4",
|
||||
"metadata": {}
|
||||
}
|
||||
|
||||
result = get_tags_from_request_body(request_body=request_body)
|
||||
|
||||
assert result == []
|
||||
|
||||
@@ -135,3 +135,127 @@ async def test_update_daily_spend_with_null_entity_id():
|
||||
assert create_data["api_requests"] == 1
|
||||
assert create_data["successful_requests"] == 1
|
||||
assert create_data["failed_requests"] == 0
|
||||
|
||||
|
||||
# Tag Spend Tracking Tests
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_tag_db_with_valid_tags():
|
||||
"""
|
||||
Test that _update_tag_db correctly processes valid tags and adds them to the spend update queue.
|
||||
"""
|
||||
from litellm.proxy._types import Litellm_EntityType, SpendUpdateQueueItem
|
||||
|
||||
writer = DBSpendUpdateWriter()
|
||||
mock_prisma = MagicMock()
|
||||
response_cost = 0.05
|
||||
request_tags = '["prod-tag", "test-tag"]'
|
||||
|
||||
writer.spend_update_queue.add_update = AsyncMock()
|
||||
|
||||
await writer._update_tag_db(
|
||||
response_cost=response_cost,
|
||||
request_tags=request_tags,
|
||||
prisma_client=mock_prisma,
|
||||
)
|
||||
|
||||
assert writer.spend_update_queue.add_update.call_count == 2
|
||||
|
||||
first_call_args = writer.spend_update_queue.add_update.call_args_list[0][1]
|
||||
assert first_call_args["update"]["entity_type"] == Litellm_EntityType.TAG
|
||||
assert first_call_args["update"]["entity_id"] == "prod-tag"
|
||||
assert first_call_args["update"]["response_cost"] == response_cost
|
||||
|
||||
second_call_args = writer.spend_update_queue.add_update.call_args_list[1][1]
|
||||
assert second_call_args["update"]["entity_type"] == Litellm_EntityType.TAG
|
||||
assert second_call_args["update"]["entity_id"] == "test-tag"
|
||||
assert second_call_args["update"]["response_cost"] == response_cost
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_tag_db_with_list_input():
|
||||
"""
|
||||
Test that _update_tag_db correctly handles tags passed as a list instead of JSON string.
|
||||
"""
|
||||
writer = DBSpendUpdateWriter()
|
||||
mock_prisma = MagicMock()
|
||||
response_cost = 0.1
|
||||
request_tags = ["tag1", "tag2", "tag3"]
|
||||
|
||||
writer.spend_update_queue.add_update = AsyncMock()
|
||||
|
||||
await writer._update_tag_db(
|
||||
response_cost=response_cost,
|
||||
request_tags=request_tags,
|
||||
prisma_client=mock_prisma,
|
||||
)
|
||||
|
||||
assert writer.spend_update_queue.add_update.call_count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_tag_db_with_no_tags():
|
||||
"""
|
||||
Test that _update_tag_db handles None and empty tags gracefully.
|
||||
"""
|
||||
writer = DBSpendUpdateWriter()
|
||||
mock_prisma = MagicMock()
|
||||
response_cost = 0.05
|
||||
|
||||
writer.spend_update_queue.add_update = AsyncMock()
|
||||
|
||||
await writer._update_tag_db(
|
||||
response_cost=response_cost,
|
||||
request_tags=None,
|
||||
prisma_client=mock_prisma,
|
||||
)
|
||||
assert writer.spend_update_queue.add_update.call_count == 0
|
||||
|
||||
await writer._update_tag_db(
|
||||
response_cost=response_cost,
|
||||
request_tags=[],
|
||||
prisma_client=mock_prisma,
|
||||
)
|
||||
assert writer.spend_update_queue.add_update.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_tag_db_with_invalid_json():
|
||||
"""
|
||||
Test that _update_tag_db handles invalid JSON gracefully.
|
||||
"""
|
||||
writer = DBSpendUpdateWriter()
|
||||
mock_prisma = MagicMock()
|
||||
response_cost = 0.05
|
||||
request_tags = '{"invalid": json}'
|
||||
|
||||
writer.spend_update_queue.add_update = AsyncMock()
|
||||
|
||||
await writer._update_tag_db(
|
||||
response_cost=response_cost,
|
||||
request_tags=request_tags,
|
||||
prisma_client=mock_prisma,
|
||||
)
|
||||
|
||||
assert writer.spend_update_queue.add_update.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_tag_db_without_prisma_client():
|
||||
"""
|
||||
Test that _update_tag_db returns early when prisma_client is None.
|
||||
"""
|
||||
writer = DBSpendUpdateWriter()
|
||||
response_cost = 0.05
|
||||
request_tags = '["tag1"]'
|
||||
|
||||
writer.spend_update_queue.add_update = AsyncMock()
|
||||
|
||||
await writer._update_tag_db(
|
||||
response_cost=response_cost,
|
||||
request_tags=request_tags,
|
||||
prisma_client=None,
|
||||
)
|
||||
|
||||
assert writer.spend_update_queue.add_update.call_count == 0
|
||||
|
||||
@@ -2184,3 +2184,107 @@ def test_should_load_db_object_with_supported_db_objects():
|
||||
)
|
||||
assert proxy_config._should_load_db_object(object_type="prompts") is True
|
||||
assert proxy_config._should_load_db_object(object_type="model_cost_map") is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tag_cache_update_called():
|
||||
"""
|
||||
Test that update_cache updates tag cache when tags are provided.
|
||||
"""
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy.proxy_server import user_api_key_cache
|
||||
|
||||
cache = DualCache()
|
||||
|
||||
setattr(
|
||||
litellm.proxy.proxy_server,
|
||||
"user_api_key_cache",
|
||||
cache,
|
||||
)
|
||||
|
||||
mock_tag_obj = {
|
||||
"tag_name": "test-tag",
|
||||
"spend": 10.0,
|
||||
}
|
||||
|
||||
with patch.object(cache, "async_get_cache", new=AsyncMock(return_value=mock_tag_obj)) as mock_get_cache:
|
||||
with patch.object(cache, "async_set_cache_pipeline", new=AsyncMock()) as mock_set_cache:
|
||||
await litellm.proxy.proxy_server.update_cache(
|
||||
token=None,
|
||||
user_id=None,
|
||||
end_user_id=None,
|
||||
team_id=None,
|
||||
response_cost=5.0,
|
||||
parent_otel_span=None,
|
||||
tags=["test-tag"],
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
mock_get_cache.assert_awaited_once_with(key="tag:test-tag")
|
||||
mock_set_cache.assert_awaited_once()
|
||||
|
||||
call_args = mock_set_cache.call_args
|
||||
cache_list = call_args.kwargs["cache_list"]
|
||||
|
||||
assert len(cache_list) == 1
|
||||
cache_key, cache_value = cache_list[0]
|
||||
assert cache_key == "tag:test-tag"
|
||||
assert cache_value["spend"] == 15.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tag_cache_update_multiple_tags():
|
||||
"""
|
||||
Test that multiple tags are updated in cache.
|
||||
"""
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy.proxy_server import user_api_key_cache
|
||||
|
||||
cache = DualCache()
|
||||
|
||||
setattr(
|
||||
litellm.proxy.proxy_server,
|
||||
"user_api_key_cache",
|
||||
cache,
|
||||
)
|
||||
|
||||
mock_tag1_obj = {"tag_name": "tag1", "spend": 10.0}
|
||||
mock_tag2_obj = {"tag_name": "tag2", "spend": 20.0}
|
||||
|
||||
async def mock_get_cache_side_effect(key):
|
||||
if key == "tag:tag1":
|
||||
return mock_tag1_obj
|
||||
elif key == "tag:tag2":
|
||||
return mock_tag2_obj
|
||||
return None
|
||||
|
||||
with patch.object(cache, "async_get_cache", new=AsyncMock(side_effect=mock_get_cache_side_effect)) as mock_get_cache:
|
||||
with patch.object(cache, "async_set_cache_pipeline", new=AsyncMock()) as mock_set_cache:
|
||||
await litellm.proxy.proxy_server.update_cache(
|
||||
token=None,
|
||||
user_id=None,
|
||||
end_user_id=None,
|
||||
team_id=None,
|
||||
response_cost=5.0,
|
||||
parent_otel_span=None,
|
||||
tags=["tag1", "tag2"],
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert mock_get_cache.call_count == 2
|
||||
mock_set_cache.assert_awaited_once()
|
||||
|
||||
call_args = mock_set_cache.call_args
|
||||
cache_list = call_args.kwargs["cache_list"]
|
||||
|
||||
assert len(cache_list) == 2
|
||||
|
||||
tag_updates = {cache_key: cache_value for cache_key, cache_value in cache_list}
|
||||
assert "tag:tag1" in tag_updates
|
||||
assert "tag:tag2" in tag_updates
|
||||
assert tag_updates["tag:tag1"]["spend"] == 15.0
|
||||
assert tag_updates["tag:tag2"]["spend"] == 25.0
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import CreateTagModal from "./CreateTagModal";
|
||||
|
||||
describe("CreateTagModal", () => {
|
||||
const mockOnCancel = vi.fn();
|
||||
const mockOnSubmit = vi.fn();
|
||||
const mockAvailableModels = [
|
||||
{
|
||||
model_name: "gpt-4",
|
||||
litellm_params: {
|
||||
model: "gpt-4",
|
||||
},
|
||||
model_info: {
|
||||
id: "model-123",
|
||||
},
|
||||
},
|
||||
{
|
||||
model_name: "claude-3",
|
||||
litellm_params: {
|
||||
model: "claude-3",
|
||||
},
|
||||
model_info: {
|
||||
id: "model-456",
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should submit form with correct values when user creates a tag", async () => {
|
||||
/**
|
||||
* Tests that filling out the form and clicking submit calls onSubmit with the correct values.
|
||||
* This is the core functionality of the CreateTagModal component.
|
||||
*/
|
||||
const user = userEvent.setup();
|
||||
|
||||
render(
|
||||
<CreateTagModal
|
||||
visible={true}
|
||||
onCancel={mockOnCancel}
|
||||
onSubmit={mockOnSubmit}
|
||||
availableModels={mockAvailableModels}
|
||||
/>
|
||||
);
|
||||
|
||||
// Wait for modal to be visible
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Create New Tag")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Fill in the tag name
|
||||
const tagNameInput = screen.getByLabelText("Tag Name");
|
||||
await user.type(tagNameInput, "production-tag");
|
||||
|
||||
// Fill in the description
|
||||
const descriptionTextarea = screen.getByLabelText("Description");
|
||||
await user.type(descriptionTextarea, "Tag for production environment");
|
||||
|
||||
// Submit the form
|
||||
const createButton = screen.getByText("Create Tag");
|
||||
await user.click(createButton);
|
||||
|
||||
// Verify onSubmit was called with correct values
|
||||
await waitFor(() => {
|
||||
expect(mockOnSubmit).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tag_name: "production-tag",
|
||||
description: "Tag for production environment",
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
// Verify onCancel was not called
|
||||
expect(mockOnCancel).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
import React from "react";
|
||||
import { Button, TextInput, Accordion, AccordionHeader, AccordionBody, Title } from "@tremor/react";
|
||||
import { Modal, Form, Select as Select2, Tooltip, Input } from "antd";
|
||||
import { InfoCircleOutlined } from "@ant-design/icons";
|
||||
import NumericalInput from "../../shared/numerical_input";
|
||||
import BudgetDurationDropdown from "../../common_components/budget_duration_dropdown";
|
||||
|
||||
interface ModelInfo {
|
||||
model_name: string;
|
||||
litellm_params: {
|
||||
model: string;
|
||||
};
|
||||
model_info: {
|
||||
id: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface CreateTagModalProps {
|
||||
visible: boolean;
|
||||
onCancel: () => void;
|
||||
onSubmit: (values: any) => void;
|
||||
availableModels: ModelInfo[];
|
||||
}
|
||||
|
||||
const CreateTagModal: React.FC<CreateTagModalProps> = ({
|
||||
visible,
|
||||
onCancel,
|
||||
onSubmit,
|
||||
availableModels,
|
||||
}) => {
|
||||
const [form] = Form.useForm();
|
||||
|
||||
const handleFinish = (values: any) => {
|
||||
onSubmit(values);
|
||||
form.resetFields();
|
||||
};
|
||||
|
||||
const handleCancel = () => {
|
||||
form.resetFields();
|
||||
onCancel();
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title="Create New Tag"
|
||||
visible={visible}
|
||||
width={800}
|
||||
footer={null}
|
||||
onCancel={handleCancel}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
onFinish={handleFinish}
|
||||
labelCol={{ span: 8 }}
|
||||
wrapperCol={{ span: 16 }}
|
||||
labelAlign="left"
|
||||
>
|
||||
<Form.Item
|
||||
label="Tag Name"
|
||||
name="tag_name"
|
||||
rules={[{ required: true, message: "Please input a tag name" }]}
|
||||
>
|
||||
<TextInput />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item label="Description" name="description">
|
||||
<Input.TextArea rows={4} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
label={
|
||||
<span>
|
||||
Allowed Models{" "}
|
||||
<Tooltip title="Select which LLMs are allowed to process requests from this tag">
|
||||
<InfoCircleOutlined style={{ marginLeft: "4px" }} />
|
||||
</Tooltip>
|
||||
</span>
|
||||
}
|
||||
name="allowed_llms"
|
||||
>
|
||||
<Select2 mode="multiple" placeholder="Select LLMs">
|
||||
{availableModels.map((model) => (
|
||||
<Select2.Option key={model.model_info.id} value={model.model_info.id}>
|
||||
<div>
|
||||
<span>{model.model_name}</span>
|
||||
<span className="text-gray-400 ml-2">({model.model_info.id})</span>
|
||||
</div>
|
||||
</Select2.Option>
|
||||
))}
|
||||
</Select2>
|
||||
</Form.Item>
|
||||
|
||||
<Accordion className="mt-4 mb-4">
|
||||
<AccordionHeader>
|
||||
<Title className="m-0">Budget & Rate Limits (Optional)</Title>
|
||||
</AccordionHeader>
|
||||
<AccordionBody>
|
||||
<Form.Item
|
||||
className="mt-4"
|
||||
label={
|
||||
<span>
|
||||
Max Budget (USD){" "}
|
||||
<Tooltip title="Maximum amount in USD this tag can spend. When reached, requests with this tag will be blocked">
|
||||
<InfoCircleOutlined style={{ marginLeft: "4px" }} />
|
||||
</Tooltip>
|
||||
</span>
|
||||
}
|
||||
name="max_budget"
|
||||
>
|
||||
<NumericalInput step={0.01} precision={2} width={200} />
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
className="mt-4"
|
||||
label={
|
||||
<span>
|
||||
Reset Budget{" "}
|
||||
<Tooltip title="How often the budget should reset. For example, setting 'daily' will reset the budget every 24 hours">
|
||||
<InfoCircleOutlined style={{ marginLeft: "4px" }} />
|
||||
</Tooltip>
|
||||
</span>
|
||||
}
|
||||
name="budget_duration"
|
||||
>
|
||||
<BudgetDurationDropdown onChange={(value) => form.setFieldValue("budget_duration", value)} />
|
||||
</Form.Item>
|
||||
|
||||
<div className="mt-4 p-3 bg-gray-50 rounded-md border border-gray-200">
|
||||
<p className="text-sm text-gray-600">
|
||||
TPM/RPM limits for tags are not currently supported. If you need this feature, please{" "}
|
||||
<a
|
||||
href="https://github.com/BerriAI/litellm/issues/new"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-blue-600 hover:text-blue-800 underline"
|
||||
>
|
||||
create a GitHub issue
|
||||
</a>
|
||||
.
|
||||
</p>
|
||||
</div>
|
||||
</AccordionBody>
|
||||
</Accordion>
|
||||
|
||||
<div style={{ textAlign: "right", marginTop: "10px" }}>
|
||||
<Button type="submit">Create Tag</Button>
|
||||
</div>
|
||||
</Form>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
export default CreateTagModal;
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { Icon, Button, Col, Text, Grid, TextInput } from "@tremor/react";
|
||||
import { Icon, Button, Col, Text, Grid } from "@tremor/react";
|
||||
import { RefreshIcon } from "@heroicons/react/outline";
|
||||
import { Modal, Form, Select as Select2, Tooltip, Input } from "antd";
|
||||
import { InfoCircleOutlined } from "@ant-design/icons";
|
||||
import TagInfoView from "./tag_info";
|
||||
import { modelInfoCall } from "../networking";
|
||||
import { tagCreateCall, tagListCall, tagDeleteCall } from "../networking";
|
||||
import { Tag } from "./types";
|
||||
import TagTable from "./TagTable";
|
||||
import NotificationsManager from "../molecules/notifications_manager";
|
||||
import CreateTagModal from "./components/CreateTagModal";
|
||||
|
||||
interface ModelInfo {
|
||||
model_name: string;
|
||||
@@ -34,7 +33,6 @@ const TagManagement: React.FC<TagProps> = ({ accessToken, userID, userRole }) =>
|
||||
const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false);
|
||||
const [tagToDelete, setTagToDelete] = useState<string | null>(null);
|
||||
const [lastRefreshed, setLastRefreshed] = useState("");
|
||||
const [form] = Form.useForm();
|
||||
const [availableModels, setAvailableModels] = useState<ModelInfo[]>([]);
|
||||
|
||||
const fetchTags = async () => {
|
||||
@@ -62,10 +60,14 @@ const TagManagement: React.FC<TagProps> = ({ accessToken, userID, userRole }) =>
|
||||
name: formValues.tag_name,
|
||||
description: formValues.description,
|
||||
models: formValues.allowed_llms,
|
||||
max_budget: formValues.max_budget,
|
||||
soft_budget: formValues.soft_budget,
|
||||
tpm_limit: formValues.tpm_limit,
|
||||
rpm_limit: formValues.rpm_limit,
|
||||
budget_duration: formValues.budget_duration,
|
||||
});
|
||||
NotificationsManager.success("Tag created successfully");
|
||||
setIsCreateModalVisible(false);
|
||||
form.resetFields();
|
||||
fetchTags();
|
||||
} catch (error) {
|
||||
console.error("Error creating tag:", error);
|
||||
@@ -173,63 +175,12 @@ const TagManagement: React.FC<TagProps> = ({ accessToken, userID, userRole }) =>
|
||||
</Grid>
|
||||
|
||||
{/* Create Tag Modal */}
|
||||
<Modal
|
||||
title="Create New Tag"
|
||||
<CreateTagModal
|
||||
visible={isCreateModalVisible}
|
||||
width={800}
|
||||
footer={null}
|
||||
onCancel={() => {
|
||||
setIsCreateModalVisible(false);
|
||||
form.resetFields();
|
||||
}}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
onFinish={handleCreate}
|
||||
labelCol={{ span: 8 }}
|
||||
wrapperCol={{ span: 16 }}
|
||||
labelAlign="left"
|
||||
>
|
||||
<Form.Item
|
||||
label="Tag Name"
|
||||
name="tag_name"
|
||||
rules={[{ required: true, message: "Please input a tag name" }]}
|
||||
>
|
||||
<TextInput />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item label="Description" name="description">
|
||||
<Input.TextArea rows={4} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
label={
|
||||
<span>
|
||||
Allowed Models{" "}
|
||||
<Tooltip title="Select which LLMs are allowed to process requests from this tag">
|
||||
<InfoCircleOutlined style={{ marginLeft: "4px" }} />
|
||||
</Tooltip>
|
||||
</span>
|
||||
}
|
||||
name="allowed_llms"
|
||||
>
|
||||
<Select2 mode="multiple" placeholder="Select LLMs">
|
||||
{availableModels.map((model) => (
|
||||
<Select2.Option key={model.model_info.id} value={model.model_info.id}>
|
||||
<div>
|
||||
<span>{model.model_name}</span>
|
||||
<span className="text-gray-400 ml-2">({model.model_info.id})</span>
|
||||
</div>
|
||||
</Select2.Option>
|
||||
))}
|
||||
</Select2>
|
||||
</Form.Item>
|
||||
|
||||
<div style={{ textAlign: "right", marginTop: "10px" }}>
|
||||
<Button type="submit">Create Tag</Button>
|
||||
</div>
|
||||
</Form>
|
||||
</Modal>
|
||||
onCancel={() => setIsCreateModalVisible(false)}
|
||||
onSubmit={handleCreate}
|
||||
availableModels={availableModels}
|
||||
/>
|
||||
|
||||
{/* Delete Confirmation Modal */}
|
||||
{isDeleteModalOpen && (
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { Card, Text, Title, Button, Badge } from "@tremor/react";
|
||||
import { Card, Text, Title, Button, Badge, Accordion, AccordionHeader, AccordionBody, Title as TremorTitle } from "@tremor/react";
|
||||
import { Form, Input, Select as Select2, Tooltip } from "antd";
|
||||
import { InfoCircleOutlined } from "@ant-design/icons";
|
||||
import { fetchUserModels } from "../organisms/create_key_button";
|
||||
@@ -7,6 +7,11 @@ import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_
|
||||
import { tagInfoCall, tagUpdateCall } from "../networking";
|
||||
import { Tag } from "./types";
|
||||
import NotificationsManager from "../molecules/notifications_manager";
|
||||
import NumericalInput from "../shared/numerical_input";
|
||||
import BudgetDurationDropdown from "../common_components/budget_duration_dropdown";
|
||||
import { copyToClipboard as utilCopyToClipboard } from "@/utils/dataUtils";
|
||||
import { CheckIcon, CopyIcon } from "lucide-react";
|
||||
import { Button as AntdButton } from "antd";
|
||||
|
||||
interface TagInfoViewProps {
|
||||
tagId: string;
|
||||
@@ -21,6 +26,17 @@ const TagInfoView: React.FC<TagInfoViewProps> = ({ tagId, onClose, accessToken,
|
||||
const [tagDetails, setTagDetails] = useState<Tag | null>(null);
|
||||
const [isEditing, setIsEditing] = useState<boolean>(editTag);
|
||||
const [userModels, setUserModels] = useState<string[]>([]);
|
||||
const [copiedStates, setCopiedStates] = useState<Record<string, boolean>>({});
|
||||
|
||||
const copyToClipboard = async (text: string | null | undefined, key: string) => {
|
||||
const success = await utilCopyToClipboard(text);
|
||||
if (success) {
|
||||
setCopiedStates((prev) => ({ ...prev, [key]: true }));
|
||||
setTimeout(() => {
|
||||
setCopiedStates((prev) => ({ ...prev, [key]: false }));
|
||||
}, 2000);
|
||||
}
|
||||
};
|
||||
|
||||
const fetchTagDetails = async () => {
|
||||
if (!accessToken) return;
|
||||
@@ -34,6 +50,8 @@ const TagInfoView: React.FC<TagInfoViewProps> = ({ tagId, onClose, accessToken,
|
||||
name: tagData.name,
|
||||
description: tagData.description,
|
||||
models: tagData.models,
|
||||
max_budget: tagData.litellm_budget_table?.max_budget,
|
||||
budget_duration: tagData.litellm_budget_table?.budget_duration,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -62,6 +80,10 @@ const TagInfoView: React.FC<TagInfoViewProps> = ({ tagId, onClose, accessToken,
|
||||
name: values.name,
|
||||
description: values.description,
|
||||
models: values.models,
|
||||
max_budget: values.max_budget,
|
||||
tpm_limit: values.tpm_limit,
|
||||
rpm_limit: values.rpm_limit,
|
||||
budget_duration: values.budget_duration,
|
||||
});
|
||||
NotificationsManager.success("Tag updated successfully");
|
||||
setIsEditing(false);
|
||||
@@ -83,7 +105,23 @@ const TagInfoView: React.FC<TagInfoViewProps> = ({ tagId, onClose, accessToken,
|
||||
<Button onClick={onClose} className="mb-4">
|
||||
← Back to Tags
|
||||
</Button>
|
||||
<Title>Tag Name: {tagDetails.name}</Title>
|
||||
<div className="flex items-center gap-2">
|
||||
<Text className="font-medium">Tag Name:</Text>
|
||||
<span className="font-mono px-2 py-1 bg-gray-100 rounded text-sm border border-gray-200">
|
||||
{tagDetails.name}
|
||||
</span>
|
||||
<AntdButton
|
||||
type="text"
|
||||
size="small"
|
||||
icon={copiedStates["tag-name"] ? <CheckIcon size={12} /> : <CopyIcon size={12} />}
|
||||
onClick={() => copyToClipboard(tagDetails.name, "tag-name")}
|
||||
className={`transition-all duration-200 ${
|
||||
copiedStates["tag-name"]
|
||||
? "text-green-600 bg-green-50 border-green-200"
|
||||
: "text-gray-500 hover:text-gray-700 hover:bg-gray-100"
|
||||
}`}
|
||||
/>
|
||||
</div>
|
||||
<Text className="text-gray-500">{tagDetails.description || "No description"}</Text>
|
||||
</div>
|
||||
{is_admin && !isEditing && <Button onClick={() => setIsEditing(true)}>Edit Tag</Button>}
|
||||
@@ -120,6 +158,56 @@ const TagInfoView: React.FC<TagInfoViewProps> = ({ tagId, onClose, accessToken,
|
||||
</Select2>
|
||||
</Form.Item>
|
||||
|
||||
<Accordion className="mt-4 mb-4">
|
||||
<AccordionHeader>
|
||||
<TremorTitle className="m-0">Budget & Rate Limits</TremorTitle>
|
||||
</AccordionHeader>
|
||||
<AccordionBody>
|
||||
<Form.Item
|
||||
label={
|
||||
<span>
|
||||
Max Budget (USD){" "}
|
||||
<Tooltip title="Maximum amount in USD this tag can spend">
|
||||
<InfoCircleOutlined style={{ marginLeft: "4px" }} />
|
||||
</Tooltip>
|
||||
</span>
|
||||
}
|
||||
name="max_budget"
|
||||
>
|
||||
<NumericalInput step={0.01} precision={2} width={200} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
label={
|
||||
<span>
|
||||
Reset Budget{" "}
|
||||
<Tooltip title="How often the budget should reset">
|
||||
<InfoCircleOutlined style={{ marginLeft: "4px" }} />
|
||||
</Tooltip>
|
||||
</span>
|
||||
}
|
||||
name="budget_duration"
|
||||
>
|
||||
<BudgetDurationDropdown onChange={(value) => form.setFieldValue("budget_duration", value)} />
|
||||
</Form.Item>
|
||||
|
||||
<div className="mt-4 p-3 bg-gray-50 rounded-md border border-gray-200">
|
||||
<p className="text-sm text-gray-600">
|
||||
TPM/RPM limits for tags are not currently supported. If you need this feature, please{" "}
|
||||
<a
|
||||
href="https://github.com/BerriAI/litellm/issues/new"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-blue-600 hover:text-blue-800 underline"
|
||||
>
|
||||
create a GitHub issue
|
||||
</a>
|
||||
.
|
||||
</p>
|
||||
</div>
|
||||
</AccordionBody>
|
||||
</Accordion>
|
||||
|
||||
<div className="flex justify-end space-x-2">
|
||||
<Button onClick={() => setIsEditing(false)}>Cancel</Button>
|
||||
<Button type="submit">Save Changes</Button>
|
||||
@@ -142,7 +230,7 @@ const TagInfoView: React.FC<TagInfoViewProps> = ({ tagId, onClose, accessToken,
|
||||
<div>
|
||||
<Text className="font-medium">Allowed LLMs</Text>
|
||||
<div className="flex flex-wrap gap-2 mt-2">
|
||||
{tagDetails.models.length === 0 ? (
|
||||
{!tagDetails.models || tagDetails.models.length === 0 ? (
|
||||
<Badge color="red">All Models</Badge>
|
||||
) : (
|
||||
tagDetails.models.map((modelId) => (
|
||||
@@ -163,6 +251,38 @@ const TagInfoView: React.FC<TagInfoViewProps> = ({ tagId, onClose, accessToken,
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
{tagDetails.litellm_budget_table && (
|
||||
<Card>
|
||||
<Title>Budget & Rate Limits</Title>
|
||||
<div className="space-y-4 mt-4">
|
||||
{tagDetails.litellm_budget_table.max_budget !== undefined && tagDetails.litellm_budget_table.max_budget !== null && (
|
||||
<div>
|
||||
<Text className="font-medium">Max Budget</Text>
|
||||
<Text>${tagDetails.litellm_budget_table.max_budget}</Text>
|
||||
</div>
|
||||
)}
|
||||
{tagDetails.litellm_budget_table.budget_duration && (
|
||||
<div>
|
||||
<Text className="font-medium">Budget Duration</Text>
|
||||
<Text>{tagDetails.litellm_budget_table.budget_duration}</Text>
|
||||
</div>
|
||||
)}
|
||||
{tagDetails.litellm_budget_table.tpm_limit !== undefined && tagDetails.litellm_budget_table.tpm_limit !== null && (
|
||||
<div>
|
||||
<Text className="font-medium">TPM Limit</Text>
|
||||
<Text>{tagDetails.litellm_budget_table.tpm_limit.toLocaleString()}</Text>
|
||||
</div>
|
||||
)}
|
||||
{tagDetails.litellm_budget_table.rpm_limit !== undefined && tagDetails.litellm_budget_table.rpm_limit !== null && (
|
||||
<div>
|
||||
<Text className="font-medium">RPM Limit</Text>
|
||||
<Text>{tagDetails.litellm_budget_table.rpm_limit.toLocaleString()}</Text>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</Card>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -7,6 +7,15 @@ export interface Tag {
|
||||
updated_at: string;
|
||||
created_by?: string;
|
||||
updated_by?: string;
|
||||
litellm_budget_table?: {
|
||||
max_budget?: number;
|
||||
soft_budget?: number;
|
||||
tpm_limit?: number;
|
||||
rpm_limit?: number;
|
||||
max_parallel_requests?: number;
|
||||
budget_duration?: string;
|
||||
model_max_budget?: any;
|
||||
};
|
||||
}
|
||||
|
||||
export interface TagInfoRequest {
|
||||
@@ -17,12 +26,22 @@ export interface TagNewRequest {
|
||||
name: string;
|
||||
description?: string;
|
||||
models: string[];
|
||||
max_budget?: number;
|
||||
soft_budget?: number;
|
||||
tpm_limit?: number;
|
||||
rpm_limit?: number;
|
||||
budget_duration?: string;
|
||||
}
|
||||
|
||||
export interface TagUpdateRequest {
|
||||
name: string;
|
||||
description?: string;
|
||||
models: string[];
|
||||
max_budget?: number;
|
||||
soft_budget?: number;
|
||||
tpm_limit?: number;
|
||||
rpm_limit?: number;
|
||||
budget_duration?: string;
|
||||
}
|
||||
|
||||
export interface TagDeleteRequest {
|
||||
|
||||
Reference in New Issue
Block a user