[Feat] Tag Management - Add support for setting tag based budgets (#15433)

* feat: add LiteLLM_TagTable

* fix: use new table for tag management

* fix - allow setting budgets for tags

* working tag creation

* fix schema.prisma

* add tag info

* ui fixes

* ui fix tag info

* TAG_CACHE_IN_MEMORY_TTL_SECONDS

* add Litellm_EntityType

* fix get_aggregated_db_spend_update_transactions

* fix: _update_entity_spend_in_db

* fix _tag_max_budget_check

* add tag budget check

* add tag_list_transactions

* test_get_tag_objects_batch

* test_update_tag_db_without_prisma_client

* fix get_tags_from_request_body

* get_tags_from_request_body

* fix get_tags_from_request_body

* fix spend tracking utils

* get_tags_from_request_body

* test_get_tags_from_request_body_with_metadata_tags

* feat: add _update_tag_cache spend tracking

* fix _PROXY_track_cost_callback

* test_tag_cache_update_multiple_tags

* fix tag info

* docs fix

* docs tag budgets

* doc fix

* docs fix

* fix tag budget

* docs tag budgets

* docs fix

* ruff fix
This commit is contained in:
Ishaan Jaff
2025-10-10 19:24:50 -07:00
committed by GitHub
parent 6b66e12dea
commit 527c8f59fa
28 changed files with 1913 additions and 244 deletions
+277
View File
@@ -0,0 +1,277 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Setting Tag Budgets
Track spend and set budgets for your API requests using tags. Tags allow you to categorize and monitor costs across different cost centers, projects, and departments.
## Pre-Requisites
- You must set up a Postgres database (e.g. Supabase, Neon, etc.)
## What are Tags?
Tags are labels you can attach to your LLM requests to track and limit spending by category.
**Common Use Cases:**
- **Cost Center Tracking**: Allocate LLM costs to specific departments or business units (e.g., "engineering", "marketing", "customer-support")
- **Project-based Budgeting**: Set budgets for different projects or initiatives (e.g., "project-alpha", "chatbot-v2")
- **Customer Attribution**: Track spend per customer or client (e.g., "customer-acme", "customer-techcorp")
- **Feature Monitoring**: Monitor costs for specific features (e.g., "feature-chat", "feature-summarization")
Tags are added to each request in the `metadata` field to track and enforce budget limits.
## Setting Tag Budgets
### 1. Create a tag with budget
Create a tag to represent a cost center, project, or any budget category. Set `max_budget` ($ value allowed) and `budget_duration` (how frequently the budget resets).
**Example:** Create a tag for your Engineering department with a monthly $500 budget
#### API
Create a new tag and set `max_budget` and `budget_duration`
```shell
curl -X POST 'http://0.0.0.0:4000/tag/new' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"name": "engineering",
"description": "Engineering department cost center",
"max_budget": 500.0,
"budget_duration": "30d"
}'
```
**Request Body Parameters:**
| Parameter | Type | Required | Description |
|-----------|------|----------|-------------|
| `name` | string | Yes | Unique name for the tag (e.g., cost center name) |
| `description` | string | No | Description of what this tag tracks |
| `models` | list[string] | No | Restrict tag to specific models |
| `max_budget` | float | No | Maximum budget in USD |
| `budget_duration` | string | No | How often budget resets (e.g., "30d", "1d") |
| `soft_budget` | float | No | Soft budget limit for warnings |
**Response:**
```json
{
"name": "engineering",
"description": "Engineering department cost center",
"max_budget": 500.0,
"budget_duration": "30d",
"budget_reset_at": "2025-11-10T00:00:00Z",
"created_at": "2025-10-11T00:00:00Z"
}
```
#### LiteLLM Admin UI
Navigate to the **Tag Management** page and click **Create New Tag**. Fill in the tag details and set your budget:
<Image
img={require('../../img/tag_budget1.png')}
style={{width: '80%', display: 'block', margin: '0'}}
/>
<br />
**Possible values for `budget_duration`:**
| `budget_duration` | When Budget will reset |
| --- | --- |
| `budget_duration="1s"` | every 1 second |
| `budget_duration="1m"` | every 1 minute |
| `budget_duration="1h"` | every 1 hour |
| `budget_duration="1d"` | every 1 day |
| `budget_duration="7d"` | every 1 week |
| `budget_duration="30d"` | every 1 month |
### 2. Use the tag in your requests
Add tags to your API requests in the `metadata` field:
:::info Tags Budgets on API Keys
Currently, tag budget enforcement is only supported per request. If you'd like to set tags on API keys so all requests automatically inherit the tags budgets, please [create a feature request on GitHub](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=enhancement&projects=&template=feature_request.yml&title=%5BFeat%5D%3A).
:::
<Tabs>
<TabItem value="openai" label="OpenAI SDK">
```python
import openai
client = openai.OpenAI(
api_key="sk-1234", # Your LiteLLM proxy key
base_url="http://0.0.0.0:4000"
)
response = client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": "Hello"}],
extra_body={
"metadata": {
"tags": ["engineering"]
}
}
)
```
</TabItem>
<TabItem value="curl" label="cURL">
```shell
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}],
"metadata": {
"tags": ["engineering"]
}
}'
```
</TabItem>
</Tabs>
### 3. Test It
Make requests until the budget is exceeded:
```shell
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}],
"metadata": {
"tags": ["engineering"]
}
}'
```
**When budget is exceeded, you'll see:**
```json
{
"error": {
"message": "Budget has been exceeded! Tag=engineering Current cost: 505.50, Max budget: 500.0",
"type": "budget_exceeded",
"param": null,
"code": "400"
}
}
```
## Managing Tags
### View Tag Information
Get information about specific tags:
```shell
curl -X POST 'http://0.0.0.0:4000/tag/info' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"names": ["engineering", "marketing"]
}'
```
**Response:**
```json
{
"engineering": {
"name": "engineering",
"description": "Engineering department cost center",
"spend": 245.50,
"max_budget": 500.0,
"budget_duration": "30d",
"budget_reset_at": "2025-11-10T00:00:00Z",
"created_at": "2025-10-11T00:00:00Z",
"updated_at": "2025-10-11T12:30:00Z"
},
"marketing": {
"name": "marketing",
"description": "Marketing department cost center",
"spend": 89.20,
"max_budget": 300.0,
"budget_duration": "30d",
"budget_reset_at": "2025-11-10T00:00:00Z",
"created_at": "2025-10-11T00:00:00Z",
"updated_at": "2025-10-11T12:30:00Z"
}
}
```
### Update Tag Budget
Update an existing tag's budget:
```shell
curl -X POST 'http://0.0.0.0:4000/tag/update' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"name": "engineering",
"max_budget": 750.0,
"budget_duration": "30d"
}'
```
### Delete Tag
```shell
curl -X POST 'http://0.0.0.0:4000/tag/delete' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"name": "engineering"
}'
```
## Multiple Tags per Request
You can apply multiple tags to a single request to track costs across different dimensions simultaneously. For example, track both the cost center and the specific project:
```python
response = client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": "Hello"}],
extra_body={
"metadata": {
"tags": ["engineering", "project-alpha", "customer-acme"]
}
}
)
```
```shell
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}],
"metadata": {
"tags": ["engineering", "project-alpha", "customer-acme"]
}
}'
```
**Budget Enforcement:** If any tag exceeds its budget, the request will be rejected.
Binary file not shown.

After

Width:  |  Height:  |  Size: 305 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 416 KiB

+3 -2
View File
@@ -189,12 +189,13 @@ const sidebars = {
type: "category",
label: "Budgets + Rate Limits",
items: [
"proxy/users",
"proxy/team_budgets",
"proxy/tag_budgets",
"proxy/customers",
"proxy/dynamic_rate_limit",
"proxy/rate_limit_tiers",
"proxy/team_budgets",
"proxy/temporary_budget_increase",
"proxy/users"
],
},
"proxy/caching",
@@ -25,6 +25,7 @@ model LiteLLM_BudgetTable {
organization LiteLLM_OrganizationTable[] // multiple orgs can have the same budget
keys LiteLLM_VerificationToken[] // multiple keys can have the same budget
end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget
tags LiteLLM_TagTable[] // multiple tags can have the same budget
team_membership LiteLLM_TeamMembership[] // budgets of Users within a Team
organization_membership LiteLLM_OrganizationMembership[] // budgets of Users within a Organization
}
@@ -245,6 +246,20 @@ model LiteLLM_EndUserTable {
blocked Boolean @default(false)
}
// Track tags with budgets and spend
model LiteLLM_TagTable {
tag_name String @id
description String?
models String[]
model_info Json? // maps model_id to model_name
spend Float @default(0.0)
budget_id String?
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
created_at DateTime @default(now()) @map("created_at")
created_by String?
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
}
// store proxy config.yaml
model LiteLLM_Config {
param_name String @id
+26
View File
@@ -185,6 +185,7 @@ class Litellm_EntityType(enum.Enum):
TEAM = "team"
TEAM_MEMBER = "team_member"
ORGANIZATION = "organization"
TAG = "tag"
# global proxy level entity
PROXY = "proxy"
@@ -2143,6 +2144,30 @@ class LiteLLM_EndUserTable(LiteLLMPydanticObjectBase):
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_TagTable(LiteLLMPydanticObjectBase):
tag_name: str
description: Optional[str] = None
models: List[str] = []
model_info: Optional[dict] = None
spend: float = 0.0
budget_id: Optional[str] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
created_at: Optional[datetime] = None
created_by: Optional[str] = None
updated_at: Optional[datetime] = None
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
if values.get("models") is None:
values.update({"models": []})
return values
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_SpendLogs(LiteLLMPydanticObjectBase):
request_id: str
api_key: str
@@ -3419,6 +3444,7 @@ class DBSpendUpdateTransactions(TypedDict):
team_list_transactions: Optional[Dict[str, float]]
team_member_list_transactions: Optional[Dict[str, float]]
org_list_transactions: Optional[Dict[str, float]]
tag_list_transactions: Optional[Dict[str, float]]
class SpendUpdateQueueItem(TypedDict, total=False):
+174
View File
@@ -35,6 +35,7 @@ from litellm.proxy._types import (
LiteLLM_ObjectPermissionTable,
LiteLLM_OrganizationMembershipTable,
LiteLLM_OrganizationTable,
LiteLLM_TagTable,
LiteLLM_TeamMembership,
LiteLLM_TeamTable,
LiteLLM_TeamTableCachedObj,
@@ -99,6 +100,8 @@ async def common_checks(
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
11. [OPTIONAL] Vector store checks - is the object allowed to access the vector store
"""
from litellm.proxy.proxy_server import prisma_client, user_api_key_cache
_model: Optional[Union[str, List[str]]] = get_model_from_request(
request_body, route
)
@@ -139,6 +142,14 @@ async def common_checks(
valid_token=valid_token,
)
await _tag_max_budget_check(
request_body=request_body,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
valid_token=valid_token,
)
# 4. If user is in budget
## 4.1 check personal budget, if personal key
if (
@@ -499,6 +510,115 @@ async def get_end_user_object(
return None
@log_db_metrics
async def get_tag_objects_batch(
tag_names: List[str],
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Dict[str, LiteLLM_TagTable]:
"""
Batch fetch multiple tag objects from cache and db.
Optimizes for latency by:
1. Fetching all cached tags in parallel
2. Batch fetching uncached tags in one DB query
Args:
tag_names: List of tag names to fetch
prisma_client: Prisma database client
user_api_key_cache: Cache for storing tag objects
parent_otel_span: Optional OpenTelemetry span for tracing
proxy_logging_obj: Optional proxy logging object
Returns:
Dictionary mapping tag_name to LiteLLM_TagTable object
"""
if prisma_client is None:
return {}
if not tag_names:
return {}
tag_objects = {}
uncached_tags = []
# Try to get all tags from cache first
for tag_name in tag_names:
cache_key = f"tag:{tag_name}"
cached_tag = await user_api_key_cache.async_get_cache(key=cache_key)
if cached_tag is not None:
if isinstance(cached_tag, dict):
tag_objects[tag_name] = LiteLLM_TagTable(**cached_tag)
else:
tag_objects[tag_name] = cached_tag
else:
uncached_tags.append(tag_name)
# Batch fetch uncached tags from DB in one query
if uncached_tags:
try:
db_tags = await prisma_client.db.litellm_tagtable.find_many(
where={"tag_name": {"in": uncached_tags}},
include={"litellm_budget_table": True},
)
# Cache and add to tag_objects
for db_tag in db_tags:
tag_name = db_tag.tag_name
cache_key = f"tag:{tag_name}"
# Cache with default TTL (same as end_user objects)
await user_api_key_cache.async_set_cache(
key=cache_key, value=db_tag.dict()
)
tag_objects[tag_name] = LiteLLM_TagTable(**db_tag.dict())
except Exception as e:
verbose_proxy_logger.debug(
f"Error batch fetching tags from database: {e}"
)
return tag_objects
@log_db_metrics
async def get_tag_object(
tag_name: Optional[str],
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_TagTable]:
"""
Returns tag object from cache or db.
Uses default cache TTL (same as end_user objects) to avoid drift.
Args:
tag_name: Name of the tag to fetch
prisma_client: Prisma database client
user_api_key_cache: Cache for storing tag objects
parent_otel_span: Optional OpenTelemetry span for tracing
proxy_logging_obj: Optional proxy logging object
Returns:
LiteLLM_TagTable object if found, None otherwise
"""
if prisma_client is None or tag_name is None:
return None
# Use batch helper for consistency
tag_objects = await get_tag_objects_batch(
tag_names=[tag_name],
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
return tag_objects.get(tag_name)
@log_db_metrics
async def get_team_membership(
user_id: str,
@@ -1642,6 +1762,60 @@ async def _team_max_budget_check(
)
async def _tag_max_budget_check(
request_body: dict,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
proxy_logging_obj: ProxyLogging,
valid_token: Optional[UserAPIKeyAuth],
):
"""
Check if any tags in the request are over their max budget.
Raises:
BudgetExceededError if any tag is over its max budget.
Triggers a budget alert if any tag is over its max budget.
"""
from litellm.proxy.common_utils.http_parsing_utils import (
get_tags_from_request_body,
)
if prisma_client is None:
return
# Get tags from request metadata
tags = get_tags_from_request_body(request_body=request_body)
if not tags:
return
# Batch fetch all tags in one go
tag_objects = await get_tag_objects_batch(
tag_names=tags,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
# Check budget for each tag
for tag_name in tags:
tag_object = tag_objects.get(tag_name)
if tag_object is None:
continue
# Check if tag has budget limits
if (
tag_object.litellm_budget_table is not None
and tag_object.litellm_budget_table.max_budget is not None
and tag_object.spend is not None
and tag_object.spend > tag_object.litellm_budget_table.max_budget
):
raise litellm.BudgetExceededError(
current_cost=tag_object.spend,
max_budget=tag_object.litellm_budget_table.max_budget,
message=f"Budget has been exceeded! Tag={tag_name} Current cost: {tag_object.spend}, Max budget: {tag_object.litellm_budget_table.max_budget}",
)
def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
"""
Check if a model matches an allowed pattern.
@@ -7,6 +7,9 @@ from fastapi import Request, UploadFile, status
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyException
from litellm.proxy.common_utils.callback_utils import (
get_metadata_variable_name_from_kwargs,
)
from litellm.types.router import Deployment
@@ -245,4 +248,23 @@ async def get_request_body(request: Request) -> Dict[str, Any]:
raise ValueError(
f"Unsupported content type: {request.headers.get('content-type')}"
)
return {}
return {}
def get_tags_from_request_body(request_body: dict) -> List[str]:
"""
Extract tags from request body metadata.
Args:
request_body: The request body dictionary
Returns:
List of tag names (strings), empty list if no valid tags found
"""
metadata_variable_name = get_metadata_variable_name_from_kwargs(request_body)
metadata = request_body.get(metadata_variable_name, {})
tags_in_metadata: List[str] = metadata.get("tags", [])
tags_in_request_body: List[str] = request_body.get("tags", [])
combined_tags: List[str] = tags_in_metadata + tags_in_request_body
return [tag for tag in combined_tags if isinstance(tag, str)]
+125
View File
@@ -17,6 +17,7 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache, RedisCache
from litellm.constants import DB_SPEND_UPDATE_JOB_NAME
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
BaseDailySpendTransaction,
@@ -147,6 +148,13 @@ class DBSpendUpdateWriter:
prisma_client=prisma_client,
)
)
asyncio.create_task(
self._update_tag_db(
response_cost=response_cost,
request_tags=payload.get("request_tags"),
prisma_client=prisma_client,
)
)
if disable_spend_logs is False:
await self._insert_spend_log_to_db(
@@ -325,6 +333,54 @@ class DBSpendUpdateWriter:
)
raise e
async def _update_tag_db(
self,
response_cost: Optional[float],
request_tags: Optional[str],
prisma_client: Optional[PrismaClient],
):
"""
Update spend for all tags in the request.
Args:
response_cost: Cost of the request
request_tags: JSON string of tags list e.g. '["prod-tag", "test-tag"]'
prisma_client: Prisma client instance
"""
try:
if request_tags is None or prisma_client is None:
return
# Parse tags from JSON string
tags = []
if isinstance(request_tags, str):
tags = safe_json_loads(request_tags, default=[])
if not tags:
verbose_proxy_logger.debug(
f"Failed to parse request_tags JSON: {request_tags}"
)
return
elif isinstance(request_tags, list):
tags = request_tags
else:
return
# Update spend for each tag
for tag_name in tags:
if tag_name and isinstance(tag_name, str):
await self.spend_update_queue.add_update(
update=SpendUpdateQueueItem(
entity_type=Litellm_EntityType.TAG,
entity_id=tag_name,
response_cost=response_cost,
)
)
except Exception as e:
verbose_proxy_logger.debug(
f"Update Tag DB failed to execute - {str(e)}\n{traceback.format_exc()}"
)
raise e
async def _insert_spend_log_to_db(
self,
payload: Union[dict, SpendLogsPayload],
@@ -775,6 +831,75 @@ class DBSpendUpdateWriter:
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
### UPDATE TAG TABLE ###
tag_list_transactions = db_spend_update_transactions["tag_list_transactions"]
await DBSpendUpdateWriter._update_entity_spend_in_db(
entity_name="Tag",
transactions=tag_list_transactions,
table_accessor="litellm_tagtable",
where_field="tag_name",
n_retry_times=n_retry_times,
prisma_client=prisma_client,
proxy_logging_obj=proxy_logging_obj,
)
@staticmethod
async def _update_entity_spend_in_db(
entity_name: str,
transactions: Optional[Dict[str, float]],
table_accessor: Any,
where_field: str,
n_retry_times: int,
prisma_client: PrismaClient,
proxy_logging_obj: ProxyLogging,
):
"""
Helper function to update spend for any entity type (team, org, tag, etc).
Args:
entity_name: Name of entity for logging (e.g., "Team", "Org", "Tag")
transactions: Dictionary of {entity_id: response_cost}
table_accessor: Prisma table accessor (e.g., prisma_client.db.litellm_teamtable)
where_field: Field name for where clause (e.g., "team_id", "organization_id", "tag_name")
n_retry_times: Number of retries on failure
prisma_client: Prisma client instance
proxy_logging_obj: Proxy logging object
"""
from litellm.proxy.utils import _raise_failed_update_spend_exception
verbose_proxy_logger.debug(
f"{entity_name} Spend transactions: {transactions}"
)
if transactions is not None and len(transactions.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
) as transaction:
async with transaction.batch_() as batcher:
for entity_id, response_cost in transactions.items():
verbose_proxy_logger.debug(
f"Updating spend for {entity_name} {where_field}={entity_id} by {response_cost}"
)
getattr(batcher, table_accessor).update_many(
where={where_field: entity_id},
data={"spend": {"increment": response_cost}},
)
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times:
_raise_failed_update_spend_exception(
e=e,
start_time=start_time,
proxy_logging_obj=proxy_logging_obj,
)
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
# fmt: off
@overload
@@ -380,6 +380,7 @@ class RedisUpdateBuffer:
team_list_transactions={},
team_member_list_transactions={},
org_list_transactions={},
tag_list_transactions={},
)
# Define the transaction fields to process
@@ -390,6 +391,7 @@ class RedisUpdateBuffer:
"team_list_transactions",
"team_member_list_transactions",
"org_list_transactions",
"tag_list_transactions",
]
# Loop through each transaction and combine the values
@@ -135,6 +135,7 @@ class SpendUpdateQueue(BaseUpdateQueue):
team_list_transactions={},
team_member_list_transactions={},
org_list_transactions={},
tag_list_transactions={},
)
# Map entity types to their corresponding transaction dictionary keys
@@ -145,6 +146,7 @@ class SpendUpdateQueue(BaseUpdateQueue):
Litellm_EntityType.TEAM: "team_list_transactions",
Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions",
Litellm_EntityType.ORGANIZATION: "org_list_transactions",
Litellm_EntityType.TAG: "tag_list_transactions",
}
for update in updates:
@@ -192,6 +194,10 @@ class SpendUpdateQueue(BaseUpdateQueue):
transactions_dict = db_spend_update_transactions[
"org_list_transactions"
]
elif dict_key == "tag_list_transactions":
transactions_dict = db_spend_update_transactions[
"tag_list_transactions"
]
else:
continue
@@ -1,7 +1,7 @@
import asyncio
import traceback
from datetime import datetime
from typing import Any, Optional, Union, cast
from typing import Any, List, Optional, Union, cast
import litellm
from litellm._logging import verbose_proxy_logger
@@ -131,6 +131,11 @@ class _ProxyDBLogger(CustomLogger):
if sl_object is not None
else kwargs.get("response_cost", None)
)
tags: Optional[List[str]] = (
sl_object.get("request_tags", None)
if sl_object is not None
else None
)
if response_cost is not None:
user_api_key = metadata.get("user_api_key", None)
@@ -172,6 +177,7 @@ class _ProxyDBLogger(CustomLogger):
response_cost=response_cost,
team_id=team_id,
parent_otel_span=parent_otel_span,
tags=tags,
)
)
@@ -11,12 +11,12 @@ Endpoints for /organization operations
#### ORGANIZATION MANAGEMENT ####
from litellm._uuid import uuid
from typing import List, Optional, Tuple
from fastapi import APIRouter, Depends, HTTPException, Request, status
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import can_user_call_model
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
@@ -29,6 +29,7 @@ from litellm.proxy.management_helpers.object_permission_utils import (
)
from litellm.proxy.management_helpers.utils import (
get_new_internal_user_defaults,
handle_budget_for_entity,
management_endpoint_wrapper,
)
from litellm.proxy.utils import PrismaClient
@@ -168,30 +169,17 @@ async def new_organization(
except Exception:
pass
if data.budget_id is None:
"""
Every organization needs a budget attached.
If none provided, create one based on provided values
"""
budget_params = LiteLLM_BudgetTable.model_fields.keys()
# Only include Budget Params when creating an entry in litellm_budgettable
_json_data = data.json(exclude_none=True)
_budget_data = {k: v for k, v in _json_data.items() if k in budget_params}
budget_row = LiteLLM_BudgetTable(**_budget_data)
new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True))
_budget = await prisma_client.db.litellm_budgettable.create(
data={
**new_budget, # type: ignore
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
}
) # type: ignore
data.budget_id = _budget.budget_id
# Handle budget creation/assignment using common helper
budget_id = await handle_budget_for_entity(
data=data,
existing_budget_id=None,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
litellm_proxy_admin_name=litellm_proxy_admin_name,
)
if budget_id is not None:
data.budget_id = budget_id
## Handle Object Permission - MCP, Vector Stores etc.
object_permission_id = await _set_object_permission(
@@ -322,20 +310,22 @@ async def update_organization(
existing_organization_row=existing_organization_row,
)
# Handle budget updates if budget fields are provided
budget_fields = {k: v for k, v in data.model_dump().items()
if k in LiteLLM_BudgetTable.model_fields.keys() and v is not None}
from litellm.proxy.proxy_server import litellm_proxy_admin_name
# Handle budget updates using common helper
budget_id = await handle_budget_for_entity(
data=data,
existing_budget_id=existing_organization_row.budget_id,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
litellm_proxy_admin_name=litellm_proxy_admin_name,
)
if budget_fields and existing_organization_row.budget_id:
await update_budget(
budget_obj=BudgetNewRequest(
budget_id=existing_organization_row.budget_id,
**budget_fields
),
user_api_key_dict=user_api_key_dict,
)
# Update budget_id if it changed
if budget_id != existing_organization_row.budget_id:
updated_organization_row["budget_id"] = budget_id
# Remove budget fields from organization update data
# Remove budget fields from organization update data (they're handled via budget table)
for field in LiteLLM_BudgetTable.model_fields.keys():
updated_organization_row.pop(field, None)
@@ -11,7 +11,6 @@ All /tag management endpoints
"""
import asyncio
import datetime
import json
from typing import TYPE_CHECKING, Dict, List, Optional
@@ -25,6 +24,7 @@ from litellm.proxy.management_endpoints.common_daily_activity import (
SpendAnalyticsPaginatedResponse,
get_daily_activity,
)
from litellm.proxy.management_helpers.utils import handle_budget_for_entity
from litellm.types.tag_management import (
LiteLLM_DailyTagSpendTable,
TagConfig,
@@ -53,69 +53,6 @@ async def _get_model_names(prisma_client, model_ids: list) -> Dict[str, str]:
return {}
async def _get_tags_config(prisma_client) -> Dict[str, TagConfig]:
"""Helper function to get tags config from db"""
try:
tags_config = await prisma_client.db.litellm_config.find_unique(
where={"param_name": "tags_config"}
)
if tags_config is None:
return {}
# Convert from JSON if needed
if isinstance(tags_config.param_value, str):
config_dict = json.loads(tags_config.param_value)
else:
config_dict = tags_config.param_value or {}
# For each tag, get the model names
for tag_name, tag_config in config_dict.items():
if isinstance(tag_config, dict) and tag_config.get("models"):
model_info = await _get_model_names(prisma_client, tag_config["models"])
tag_config["model_info"] = model_info
return config_dict
except Exception:
return {}
async def _save_tags_config(prisma_client, tags_config: Dict[str, TagConfig]):
"""Helper function to save tags config to db"""
try:
verbose_proxy_logger.debug(f"Saving tags config: {tags_config}")
# Convert TagConfig objects to dictionaries
tags_config_dict = {}
for name, tag in tags_config.items():
if isinstance(tag, TagConfig):
tag_dict = tag.model_dump()
# Remove model_info before saving as it will be dynamically generated
if "model_info" in tag_dict:
del tag_dict["model_info"]
tags_config_dict[name] = tag_dict
else:
# If it's already a dict, remove model_info
tag_copy = tag.copy()
if "model_info" in tag_copy:
del tag_copy["model_info"]
tags_config_dict[name] = tag_copy
json_tags_config = json.dumps(tags_config_dict, default=str)
verbose_proxy_logger.debug(f"JSON tags config: {json_tags_config}")
await prisma_client.db.litellm_config.upsert(
where={"param_name": "tags_config"},
data={
"create": {
"param_name": "tags_config",
"param_value": json_tags_config,
},
"update": {"param_value": json_tags_config},
},
)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error saving tags config: {str(e)}"
)
async def get_deployments_by_model(
model: str, llm_router: "Router"
) -> List["Deployment"]:
@@ -159,9 +96,23 @@ async def new_tag(
- name: str - The name of the tag
- description: Optional[str] - Description of what this tag represents
- models: List[str] - List of either 'model_id' or 'model_name' allowed for this tag
- budget_id: Optional[str] - The id for a budget (tpm/rpm/max budget) for the tag
### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ###
- max_budget: Optional[float] - Max budget for tag
- tpm_limit: Optional[int] - Max tpm limit for tag
- rpm_limit: Optional[int] - Max rpm limit for tag
- max_parallel_requests: Optional[int] - Max parallel requests for tag
- soft_budget: Optional[float] - Get a slack alert when this soft budget is reached
- model_max_budget: Optional[dict] - Max budget for a specific model
- budget_duration: Optional[str] - Frequency of resetting tag budget
"""
from litellm.proxy._types import CommonProxyErrors
from litellm.proxy.proxy_server import llm_router, prisma_client
from litellm.proxy.proxy_server import (
litellm_proxy_admin_name,
llm_router,
prisma_client,
)
if prisma_client is None:
raise HTTPException(
@@ -172,29 +123,38 @@ async def new_tag(
status_code=500, detail=CommonProxyErrors.no_llm_router.value
)
try:
# Get existing tags config
tags_config = await _get_tags_config(prisma_client)
# Check if tag already exists
if tag.name in tags_config:
existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
where={"tag_name": tag.name}
)
if existing_tag is not None:
raise HTTPException(
status_code=400, detail=f"Tag {tag.name} already exists"
)
# Add new tag
tags_config[tag.name] = TagConfig(
name=tag.name,
description=tag.description,
models=tag.models,
created_at=str(datetime.datetime.now()),
updated_at=str(datetime.datetime.now()),
created_by=user_api_key_dict.user_id,
# Handle budget creation/assignment using common helper
budget_id = await handle_budget_for_entity(
data=tag,
existing_budget_id=None,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
litellm_proxy_admin_name=litellm_proxy_admin_name,
)
# Save updated config
await _save_tags_config(
prisma_client=prisma_client,
tags_config=tags_config,
# Get model names for model_info
model_info = await _get_model_names(prisma_client, tag.models or [])
# Create new tag in database
new_tag_record = await prisma_client.db.litellm_tagtable.create(
data={
"tag_name": tag.name,
"description": tag.description,
"models": tag.models or [],
"model_info": json.dumps(model_info),
"spend": 0.0,
"budget_id": budget_id,
"created_by": user_api_key_dict.user_id,
}
)
# Update models with new tag
@@ -213,13 +173,20 @@ async def new_tag(
)
await asyncio.gather(*tasks)
# Get model names for response
model_info = await _get_model_names(prisma_client, tag.models or [])
tags_config[tag.name].model_info = model_info
# Build response
tag_config = TagConfig(
name=new_tag_record.tag_name,
description=new_tag_record.description,
models=new_tag_record.models,
model_info=model_info,
created_at=new_tag_record.created_at.isoformat(),
updated_at=new_tag_record.updated_at.isoformat(),
created_by=new_tag_record.created_by,
)
return {
"message": f"Tag {tag.name} created successfully",
"tag": tags_config[tag.name],
"tag": tag_config,
}
except Exception as e:
verbose_proxy_logger.exception(f"Error creating tag: {str(e)}")
@@ -264,6 +231,16 @@ async def update_tag(
- name: str - The name of the tag to update
- description: Optional[str] - Updated description
- models: List[str] - Updated list of allowed LLM models
- budget_id: Optional[str] - The id for a budget to associate with the tag
### BUDGET UPDATE PARAMS ###
- max_budget: Optional[float] - Max budget for tag
- tpm_limit: Optional[int] - Max tpm limit for tag
- rpm_limit: Optional[int] - Max rpm limit for tag
- max_parallel_requests: Optional[int] - Max parallel requests for tag
- soft_budget: Optional[float] - Get a slack alert when this soft budget is reached
- model_max_budget: Optional[dict] - Max budget for a specific model
- budget_duration: Optional[str] - Frequency of resetting tag budget
"""
from litellm.proxy.proxy_server import prisma_client
@@ -271,35 +248,58 @@ async def update_tag(
raise HTTPException(status_code=500, detail="Database not connected")
try:
# Get existing tags config
tags_config = await _get_tags_config(prisma_client)
# Check if tag exists
if tag.name not in tags_config:
existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
where={"tag_name": tag.name}
)
if existing_tag is None:
raise HTTPException(status_code=404, detail=f"Tag {tag.name} not found")
# Update tag
tag_config_dict = dict(tags_config[tag.name])
tag_config_dict.update(
{
"description": tag.description,
"models": tag.models,
"updated_at": str(datetime.datetime.now()),
"updated_by": user_api_key_dict.user_id,
}
from litellm.proxy.proxy_server import litellm_proxy_admin_name
# Handle budget updates using common helper
budget_id = await handle_budget_for_entity(
data=tag,
existing_budget_id=existing_tag.budget_id,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
litellm_proxy_admin_name=litellm_proxy_admin_name,
)
tags_config[tag.name] = TagConfig(**tag_config_dict)
# Save updated config
await _save_tags_config(prisma_client, tags_config)
# Get model names for response
# Get model names for model_info
model_info = await _get_model_names(prisma_client, tag.models or [])
tags_config[tag.name].model_info = model_info
# Prepare update data
update_data = {
"description": tag.description,
"models": tag.models or [],
"model_info": json.dumps(model_info),
}
# Add budget_id if it changed
if budget_id != existing_tag.budget_id:
update_data["budget_id"] = budget_id
# Update tag in database
updated_tag_record = await prisma_client.db.litellm_tagtable.update(
where={"tag_name": tag.name},
data=update_data,
)
# Build response
tag_config = TagConfig(
name=updated_tag_record.tag_name,
description=updated_tag_record.description,
models=updated_tag_record.models,
model_info=model_info,
created_at=updated_tag_record.created_at.isoformat(),
updated_at=updated_tag_record.updated_at.isoformat(),
created_by=updated_tag_record.created_by,
)
return {
"message": f"Tag {tag.name} updated successfully",
"tag": tags_config[tag.name],
"tag": tag_config,
}
except Exception as e:
verbose_proxy_logger.exception(f"Error updating tag: {str(e)}")
@@ -327,18 +327,47 @@ async def info_tag(
raise HTTPException(status_code=500, detail="Database not connected")
try:
tags_config = await _get_tags_config(prisma_client)
# Filter tags based on requested names
requested_tags = {name: tags_config.get(name) for name in data.names}
# Query tags from database with budget info
tag_records = await prisma_client.db.litellm_tagtable.find_many(
where={"tag_name": {"in": data.names}},
include={"litellm_budget_table": True},
)
# Check if any requested tags don't exist
missing_tags = [name for name in data.names if name not in tags_config]
found_tag_names = {tag.tag_name for tag in tag_records}
missing_tags = [name for name in data.names if name not in found_tag_names]
if missing_tags:
raise HTTPException(
status_code=404, detail=f"Tags not found: {missing_tags}"
)
# Build response
requested_tags = {}
for tag_record in tag_records:
# Parse model_info from JSON
model_info = {}
if tag_record.model_info:
if isinstance(tag_record.model_info, str):
model_info = json.loads(tag_record.model_info)
else:
model_info = tag_record.model_info
tag_dict = {
"name": tag_record.tag_name,
"description": tag_record.description,
"models": tag_record.models,
"model_info": model_info,
"created_at": tag_record.created_at.isoformat(),
"updated_at": tag_record.updated_at.isoformat(),
"created_by": tag_record.created_by,
}
# Add budget info if available
if hasattr(tag_record, "litellm_budget_table") and tag_record.litellm_budget_table:
tag_dict["litellm_budget_table"] = tag_record.litellm_budget_table
requested_tags[tag_record.tag_name] = tag_dict
return requested_tags
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -348,13 +377,12 @@ async def info_tag(
"/tag/list",
tags=["tag management"],
dependencies=[Depends(user_api_key_auth)],
response_model=List[TagConfig],
)
async def list_tags(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
List all available tags.
List all available tags with their budget information.
"""
from litellm.proxy.proxy_server import prisma_client
@@ -363,8 +391,37 @@ async def list_tags(
try:
## QUERY STORED TAGS ##
tags_config = await _get_tags_config(prisma_client)
list_of_tags = list(tags_config.values())
tag_records = await prisma_client.db.litellm_tagtable.find_many(
include={"litellm_budget_table": True}
)
stored_tag_names = set()
list_of_tags = []
for tag_record in tag_records:
stored_tag_names.add(tag_record.tag_name)
# Parse model_info from JSON
model_info = {}
if tag_record.model_info:
if isinstance(tag_record.model_info, str):
model_info = json.loads(tag_record.model_info)
else:
model_info = tag_record.model_info
tag_dict = {
"name": tag_record.tag_name,
"description": tag_record.description,
"models": tag_record.models,
"model_info": model_info,
"created_at": tag_record.created_at.isoformat(),
"updated_at": tag_record.updated_at.isoformat(),
"created_by": tag_record.created_by,
}
# Add budget info if available
if hasattr(tag_record, "litellm_budget_table") and tag_record.litellm_budget_table:
tag_dict["litellm_budget_table"] = tag_record.litellm_budget_table
list_of_tags.append(tag_dict)
## QUERY DYNAMIC TAGS ##
dynamic_tags = await prisma_client.db.litellm_dailytagspend.find_many(
@@ -377,15 +434,15 @@ async def list_tags(
]
dynamic_tag_config = [
TagConfig(
name=tag.tag,
description="This is just a spend tag that was passed dynamically in a request. It does not control any LLM models.",
models=None,
created_at=tag.created_at.isoformat(),
updated_at=tag.updated_at.isoformat(),
)
{
"name": tag.tag,
"description": "This is just a spend tag that was passed dynamically in a request. It does not control any LLM models.",
"models": None,
"created_at": tag.created_at.isoformat(),
"updated_at": tag.updated_at.isoformat(),
}
for tag in dynamic_tags_list
if tag.tag not in tags_config
if tag.tag not in stored_tag_names
]
return list_of_tags + dynamic_tag_config
@@ -414,18 +471,15 @@ async def delete_tag(
raise HTTPException(status_code=500, detail="Database not connected")
try:
# Get existing tags config
tags_config = await _get_tags_config(prisma_client)
# Check if tag exists
if data.name not in tags_config:
existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
where={"tag_name": data.name}
)
if existing_tag is None:
raise HTTPException(status_code=404, detail=f"Tag {data.name} not found")
# Delete tag
del tags_config[data.name]
# Save updated config
await _save_tags_config(prisma_client, tags_config)
# Delete tag from database
await prisma_client.db.litellm_tagtable.delete(where={"tag_name": data.name})
return {"message": f"Tag {data.name} deleted successfully"}
except Exception as e:
+85
View File
@@ -10,10 +10,12 @@ import litellm
from litellm._logging import verbose_logger
from litellm._uuid import uuid
from litellm.proxy._types import ( # key request types; user request types; team request types; customer request types
BudgetNewRequest,
DeleteCustomerRequest,
DeleteTeamRequest,
DeleteUserRequest,
KeyRequest,
LiteLLM_BudgetTable,
LiteLLM_TeamMembership,
LiteLLM_UserTable,
ManagementEndpointLoggingPayload,
@@ -53,6 +55,89 @@ def get_new_internal_user_defaults(
return non_null_dict
async def handle_budget_for_entity(
data,
existing_budget_id: Optional[str],
user_api_key_dict: UserAPIKeyAuth,
prisma_client: PrismaClient,
litellm_proxy_admin_name: str,
) -> Optional[str]:
"""
Common helper to handle budget creation/updates for entities (organizations, tags, etc).
This function:
1. Creates a new budget if budget_id is None but budget fields are provided
2. Updates an existing budget if budget fields are provided and budget_id exists
3. Returns the budget_id to use (existing or newly created)
Args:
data: The request object (e.g., TagNewRequest, NewOrganizationRequest, etc.) containing budget fields
existing_budget_id: The existing budget_id if updating an entity, None if creating new
user_api_key_dict: User authentication info
prisma_client: Database client
litellm_proxy_admin_name: Admin name for audit trail
Returns:
Optional[str]: The budget_id to use, or None if no budget was created/updated
"""
from litellm.proxy.management_endpoints.budget_management_endpoints import (
update_budget,
)
# Get all budget field names
budget_params = LiteLLM_BudgetTable.model_fields.keys()
# Extract budget fields from data
_json_data = data.model_dump(exclude_none=True) if hasattr(data, "model_dump") else data
_budget_data = {k: v for k, v in _json_data.items() if k in budget_params}
# Check if budget_id is explicitly provided in the data
data_budget_id = getattr(data, "budget_id", None)
# Case 1: Creating new entity - no existing budget_id
if existing_budget_id is None:
if data_budget_id is not None:
# Use the provided budget_id
return data_budget_id
elif _budget_data:
# Create a new budget with the provided fields
budget_row = LiteLLM_BudgetTable(**_budget_data)
new_budget_data = prisma_client.jsonify_object(
budget_row.model_dump(exclude_none=True)
)
_budget = await prisma_client.db.litellm_budgettable.create(
data={
**new_budget_data, # type: ignore
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
}
) # type: ignore
return _budget.budget_id
else:
# No budget fields provided, no budget to create
return None
# Case 2: Updating existing entity - has existing budget_id
else:
# If budget fields are provided, update the existing budget
if _budget_data:
await update_budget(
budget_obj=BudgetNewRequest(
budget_id=existing_budget_id, **_budget_data
),
user_api_key_dict=user_api_key_dict,
)
# If a different budget_id is explicitly provided, use that instead
if data_budget_id is not None and data_budget_id != existing_budget_id:
return data_budget_id
# Otherwise, keep using the existing budget_id
return existing_budget_id
async def add_new_member(
new_member: Member,
max_budget_in_team: Optional[float],
+57 -4
View File
@@ -33,9 +33,9 @@ from litellm.constants import (
AIOHTTP_TTL_DNS_CACHE,
BASE_MCP_ROUTE,
DEFAULT_MAX_RECURSE_DEPTH,
DEFAULT_SLACK_ALERTING_THRESHOLD,
DEFAULT_SHARED_HEALTH_CHECK_TTL,
DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL,
DEFAULT_SHARED_HEALTH_CHECK_TTL,
DEFAULT_SLACK_ALERTING_THRESHOLD,
LITELLM_EMBEDDING_PROVIDERS_SUPPORTING_INPUT_ARRAY_OF_TOKENS,
LITELLM_SETTINGS_SAFE_DB_OVERRIDES,
)
@@ -1157,6 +1157,7 @@ async def update_cache( # noqa: PLR0915
team_id: Optional[str],
response_cost: Optional[float],
parent_otel_span: Optional[Span], # type: ignore
tags: Optional[List[str]] = None,
):
"""
Use this to update the cache with new user spend.
@@ -1377,6 +1378,50 @@ async def update_cache( # noqa: PLR0915
f"An error occurred updating end user cache: {str(e)}"
)
### UPDATE TAG SPEND ###
async def _update_tag_cache():
"""
Update the tag cache with the new spend.
"""
if tags is None or response_cost is None:
return
try:
for tag_name in tags:
if not tag_name or not isinstance(tag_name, str):
continue
cache_key = f"tag:{tag_name}"
# Fetch the existing tag object from cache
existing_tag_obj = await user_api_key_cache.async_get_cache(key=cache_key)
if existing_tag_obj is None:
# do nothing if tag not in api key cache
continue
verbose_proxy_logger.debug(
f"_update_tag_cache: existing spend for tag={tag_name}: {existing_tag_obj}; response_cost: {response_cost}"
)
if isinstance(existing_tag_obj, dict):
existing_spend = existing_tag_obj.get("spend", 0) or 0
else:
existing_spend = getattr(existing_tag_obj, "spend", 0) or 0
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
# Update the spend column for the given tag
if isinstance(existing_tag_obj, dict):
existing_tag_obj["spend"] = new_spend
values_to_update_in_cache.append((cache_key, existing_tag_obj))
else:
existing_tag_obj.spend = new_spend
values_to_update_in_cache.append((cache_key, existing_tag_obj))
except Exception as e:
verbose_proxy_logger.exception(
f"An error occurred updating tag cache: {str(e)}"
)
if token is not None and response_cost is not None:
await _update_key_cache(token=token, response_cost=response_cost)
@@ -1389,6 +1434,9 @@ async def update_cache( # noqa: PLR0915
if team_id is not None:
await _update_team_cache()
if tags is not None:
await _update_tag_cache()
asyncio.create_task(
user_api_key_cache.async_set_cache_pipeline(
cache_list=values_to_update_in_cache,
@@ -1431,7 +1479,9 @@ async def _run_background_health_check():
# Initialize shared health check manager if Redis is available and feature is enabled
shared_health_manager = None
if use_shared_health_check and redis_usage_cache is not None:
from litellm.proxy.health_check_utils.shared_health_check_manager import SharedHealthCheckManager
from litellm.proxy.health_check_utils.shared_health_check_manager import (
SharedHealthCheckManager,
)
shared_health_manager = SharedHealthCheckManager(
redis_cache=redis_usage_cache,
health_check_ttl=DEFAULT_SHARED_HEALTH_CHECK_TTL,
@@ -1451,10 +1501,13 @@ async def _run_background_health_check():
]
# Use shared health check if available, otherwise fall back to direct health check
# Convert health_check_details to bool for perform_shared_health_check (defaults to True if None)
details_bool = health_check_details if health_check_details is not None else True
if shared_health_manager is not None:
try:
healthy_endpoints, unhealthy_endpoints = await shared_health_manager.perform_shared_health_check(
model_list=_llm_model_list, details=health_check_details
model_list=_llm_model_list, details=details_bool
)
except Exception as e:
verbose_proxy_logger.error(
+15
View File
@@ -25,6 +25,7 @@ model LiteLLM_BudgetTable {
organization LiteLLM_OrganizationTable[] // multiple orgs can have the same budget
keys LiteLLM_VerificationToken[] // multiple keys can have the same budget
end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget
tags LiteLLM_TagTable[] // multiple tags can have the same budget
team_membership LiteLLM_TeamMembership[] // budgets of Users within a Team
organization_membership LiteLLM_OrganizationMembership[] // budgets of Users within a Organization
}
@@ -245,6 +246,20 @@ model LiteLLM_EndUserTable {
blocked Boolean @default(false)
}
// Track tags with budgets and spend
model LiteLLM_TagTable {
tag_name String @id
description String?
models String[]
model_info Json? // maps model_id to model_name
spend Float @default(0.0)
budget_id String?
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
created_at DateTime @default(now()) @map("created_at")
created_by String?
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
}
// store proxy config.yaml
model LiteLLM_Config {
param_name String @id
+18 -2
View File
@@ -18,11 +18,27 @@ class TagConfig(TagBase):
class TagNewRequest(TagBase):
pass
budget_id: Optional[str] = None
# Budget fields - if budget_id is None, create a new budget with these params
max_budget: Optional[float] = None
soft_budget: Optional[float] = None
max_parallel_requests: Optional[int] = None
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
model_max_budget: Optional[Dict] = None
budget_duration: Optional[str] = None
class TagUpdateRequest(TagBase):
pass
budget_id: Optional[str] = None
# Budget fields - if provided, will update the budget
max_budget: Optional[float] = None
soft_budget: Optional[float] = None
max_parallel_requests: Optional[int] = None
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
model_max_budget: Optional[Dict] = None
budget_duration: Optional[str] = None
class TagDeleteRequest(BaseModel):
+15
View File
@@ -25,6 +25,7 @@ model LiteLLM_BudgetTable {
organization LiteLLM_OrganizationTable[] // multiple orgs can have the same budget
keys LiteLLM_VerificationToken[] // multiple keys can have the same budget
end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget
tags LiteLLM_TagTable[] // multiple tags can have the same budget
team_membership LiteLLM_TeamMembership[] // budgets of Users within a Team
organization_membership LiteLLM_OrganizationMembership[] // budgets of Users within a Organization
}
@@ -245,6 +246,20 @@ model LiteLLM_EndUserTable {
blocked Boolean @default(false)
}
// Track tags with budgets and spend
model LiteLLM_TagTable {
tag_name String @id
description String?
models String[]
model_info Json? // maps model_id to model_name
spend Float @default(0.0)
budget_id String?
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
created_at DateTime @default(now()) @map("created_at")
created_by String?
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
}
// store proxy config.yaml
model LiteLLM_Config {
param_name String @id
@@ -26,9 +26,9 @@ from litellm.proxy._types import (
from litellm.proxy.auth.auth_checks import (
ExperimentalUIJWTToken,
_can_object_call_vector_stores,
_get_team_db_check,
get_user_object,
vector_store_access_check,
_get_team_db_check,
)
from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper
from litellm.utils import get_utc_datetime
@@ -567,3 +567,141 @@ def test_can_object_call_model_no_access_to_alias_or_underlying():
assert exc_info.value.type == ProxyErrorTypes.key_model_access_denied
assert "key not allowed to access model" in str(exc_info.value.message)
assert "my-fake-gpt" in str(exc_info.value.message)
# Tag Budget Enforcement Tests
@pytest.mark.asyncio
async def test_get_tag_objects_batch():
"""
Test batch fetching of tags validates:
- Cached tags are fetched from cache (no DB call for them)
- Uncached tags are fetched in ONE batch DB query
- After fetching, uncached tags are cached
"""
from litellm.proxy._types import LiteLLM_TagTable
from litellm.proxy.auth.auth_checks import get_tag_objects_batch
mock_prisma = MagicMock()
mock_cache = MagicMock()
mock_proxy_logging = MagicMock()
# Simulate 5 tags: 2 cached, 3 uncached
tag_names = ["cached-1", "uncached-1", "cached-2", "uncached-2", "uncached-3"]
# Mock cached tags
cached_tag_1 = {
"tag_name": "cached-1",
"spend": 10.0,
"models": [],
"litellm_budget_table": None,
}
cached_tag_2 = {
"tag_name": "cached-2",
"spend": 20.0,
"models": [],
"litellm_budget_table": None,
}
# Mock DB response for uncached tags
uncached_tag_1 = MagicMock()
uncached_tag_1.tag_name = "uncached-1"
uncached_tag_1.spend = 30.0
uncached_tag_1.models = []
uncached_tag_1.litellm_budget_table = None
uncached_tag_1.dict = MagicMock(
return_value={
"tag_name": "uncached-1",
"spend": 30.0,
"models": [],
"litellm_budget_table": None,
}
)
uncached_tag_2 = MagicMock()
uncached_tag_2.tag_name = "uncached-2"
uncached_tag_2.spend = 40.0
uncached_tag_2.models = []
uncached_tag_2.litellm_budget_table = None
uncached_tag_2.dict = MagicMock(
return_value={
"tag_name": "uncached-2",
"spend": 40.0,
"models": [],
"litellm_budget_table": None,
}
)
uncached_tag_3 = MagicMock()
uncached_tag_3.tag_name = "uncached-3"
uncached_tag_3.spend = 50.0
uncached_tag_3.models = []
uncached_tag_3.litellm_budget_table = None
uncached_tag_3.dict = MagicMock(
return_value={
"tag_name": "uncached-3",
"spend": 50.0,
"models": [],
"litellm_budget_table": None,
}
)
# Mock cache behavior - return cached tags, None for uncached
async def mock_get_cache(key):
if key == "tag:cached-1":
return cached_tag_1
elif key == "tag:cached-2":
return cached_tag_2
else:
return None
mock_cache.async_get_cache = AsyncMock(side_effect=mock_get_cache)
mock_cache.async_set_cache = AsyncMock()
# Mock DB to return all uncached tags in ONE query
mock_prisma.db.litellm_tagtable.find_many = AsyncMock(
return_value=[uncached_tag_1, uncached_tag_2, uncached_tag_3]
)
# Call batch fetch
tag_objects = await get_tag_objects_batch(
tag_names=tag_names,
prisma_client=mock_prisma,
user_api_key_cache=mock_cache,
proxy_logging_obj=mock_proxy_logging,
)
# Verify results
assert len(tag_objects) == 5
assert "cached-1" in tag_objects
assert "cached-2" in tag_objects
assert "uncached-1" in tag_objects
assert "uncached-2" in tag_objects
assert "uncached-3" in tag_objects
# Verify cached tags have correct values
assert tag_objects["cached-1"].spend == 10.0
assert tag_objects["cached-2"].spend == 20.0
# Verify uncached tags have correct values
assert tag_objects["uncached-1"].spend == 30.0
assert tag_objects["uncached-2"].spend == 40.0
assert tag_objects["uncached-3"].spend == 50.0
# Verify DB was called ONCE with all 3 uncached tags
mock_prisma.db.litellm_tagtable.find_many.assert_called_once()
call_args = mock_prisma.db.litellm_tagtable.find_many.call_args
assert call_args.kwargs["where"]["tag_name"]["in"] == [
"uncached-1",
"uncached-2",
"uncached-3",
]
# Verify uncached tags were cached after fetching
assert mock_cache.async_set_cache.call_count == 3
cache_calls = mock_cache.async_set_cache.call_args_list
cached_keys = [call.kwargs["key"] for call in cache_calls]
assert "tag:uncached-1" in cached_keys
assert "tag:uncached-2" in cached_keys
assert "tag:uncached-3" in cached_keys
@@ -14,13 +14,17 @@ sys.path.insert(
import litellm
from litellm.proxy._types import ProxyException
from litellm.proxy.common_utils.http_parsing_utils import (
_read_request_body,
_safe_get_request_headers,
_safe_get_request_parsed_body,
_safe_get_request_query_params,
_safe_set_request_parsed_body,
get_form_data,
get_request_body,
get_tags_from_request_body,
)
from litellm.proxy._types import ProxyException
@pytest.mark.asyncio
@@ -285,3 +289,96 @@ async def test_get_form_data():
# Note: In a real MultiDict, both values would be present
# But in our mock dictionary the second value overwrites the first
assert "segment" in result["timestamp_granularities"]
def test_get_tags_from_request_body_with_metadata_tags():
"""
Test that tags are correctly extracted from request body metadata.
"""
request_body = {
"model": "gpt-4",
"metadata": {
"tags": ["tag1", "tag2", "tag3"]
}
}
result = get_tags_from_request_body(request_body=request_body)
assert result == ["tag1", "tag2", "tag3"]
def test_get_tags_from_request_body_with_litellm_metadata_tags():
"""
Test that tags are correctly extracted from request body when using litellm_metadata.
"""
request_body = {
"model": "gpt-4",
"litellm_metadata": {
"tags": ["tag1", "tag2", "tag3"]
}
}
result = get_tags_from_request_body(request_body=request_body)
assert result == ["tag1", "tag2", "tag3"]
def test_get_tags_from_request_body_with_root_tags():
"""
Test that tags are correctly extracted from root level of request body.
"""
request_body = {
"model": "gpt-4",
"tags": ["tag1", "tag2"]
}
result = get_tags_from_request_body(request_body=request_body)
assert result == ["tag1", "tag2"]
def test_get_tags_from_request_body_with_combined_tags():
"""
Test that tags from both metadata and root level are combined.
"""
request_body = {
"model": "gpt-4",
"metadata": {
"tags": ["tag1", "tag2"]
},
"tags": ["tag3", "tag4"]
}
result = get_tags_from_request_body(request_body=request_body)
assert result == ["tag1", "tag2", "tag3", "tag4"]
def test_get_tags_from_request_body_filters_non_strings():
"""
Test that non-string values in tags list are filtered out.
"""
request_body = {
"model": "gpt-4",
"metadata": {
"tags": ["tag1", 123, "tag2", None, "tag3", {"nested": "dict"}]
}
}
result = get_tags_from_request_body(request_body=request_body)
assert result == ["tag1", "tag2", "tag3"]
def test_get_tags_from_request_body_no_tags():
"""
Test that empty list is returned when no tags are present.
"""
request_body = {
"model": "gpt-4",
"metadata": {}
}
result = get_tags_from_request_body(request_body=request_body)
assert result == []
@@ -135,3 +135,127 @@ async def test_update_daily_spend_with_null_entity_id():
assert create_data["api_requests"] == 1
assert create_data["successful_requests"] == 1
assert create_data["failed_requests"] == 0
# Tag Spend Tracking Tests
@pytest.mark.asyncio
async def test_update_tag_db_with_valid_tags():
"""
Test that _update_tag_db correctly processes valid tags and adds them to the spend update queue.
"""
from litellm.proxy._types import Litellm_EntityType, SpendUpdateQueueItem
writer = DBSpendUpdateWriter()
mock_prisma = MagicMock()
response_cost = 0.05
request_tags = '["prod-tag", "test-tag"]'
writer.spend_update_queue.add_update = AsyncMock()
await writer._update_tag_db(
response_cost=response_cost,
request_tags=request_tags,
prisma_client=mock_prisma,
)
assert writer.spend_update_queue.add_update.call_count == 2
first_call_args = writer.spend_update_queue.add_update.call_args_list[0][1]
assert first_call_args["update"]["entity_type"] == Litellm_EntityType.TAG
assert first_call_args["update"]["entity_id"] == "prod-tag"
assert first_call_args["update"]["response_cost"] == response_cost
second_call_args = writer.spend_update_queue.add_update.call_args_list[1][1]
assert second_call_args["update"]["entity_type"] == Litellm_EntityType.TAG
assert second_call_args["update"]["entity_id"] == "test-tag"
assert second_call_args["update"]["response_cost"] == response_cost
@pytest.mark.asyncio
async def test_update_tag_db_with_list_input():
"""
Test that _update_tag_db correctly handles tags passed as a list instead of JSON string.
"""
writer = DBSpendUpdateWriter()
mock_prisma = MagicMock()
response_cost = 0.1
request_tags = ["tag1", "tag2", "tag3"]
writer.spend_update_queue.add_update = AsyncMock()
await writer._update_tag_db(
response_cost=response_cost,
request_tags=request_tags,
prisma_client=mock_prisma,
)
assert writer.spend_update_queue.add_update.call_count == 3
@pytest.mark.asyncio
async def test_update_tag_db_with_no_tags():
"""
Test that _update_tag_db handles None and empty tags gracefully.
"""
writer = DBSpendUpdateWriter()
mock_prisma = MagicMock()
response_cost = 0.05
writer.spend_update_queue.add_update = AsyncMock()
await writer._update_tag_db(
response_cost=response_cost,
request_tags=None,
prisma_client=mock_prisma,
)
assert writer.spend_update_queue.add_update.call_count == 0
await writer._update_tag_db(
response_cost=response_cost,
request_tags=[],
prisma_client=mock_prisma,
)
assert writer.spend_update_queue.add_update.call_count == 0
@pytest.mark.asyncio
async def test_update_tag_db_with_invalid_json():
"""
Test that _update_tag_db handles invalid JSON gracefully.
"""
writer = DBSpendUpdateWriter()
mock_prisma = MagicMock()
response_cost = 0.05
request_tags = '{"invalid": json}'
writer.spend_update_queue.add_update = AsyncMock()
await writer._update_tag_db(
response_cost=response_cost,
request_tags=request_tags,
prisma_client=mock_prisma,
)
assert writer.spend_update_queue.add_update.call_count == 0
@pytest.mark.asyncio
async def test_update_tag_db_without_prisma_client():
"""
Test that _update_tag_db returns early when prisma_client is None.
"""
writer = DBSpendUpdateWriter()
response_cost = 0.05
request_tags = '["tag1"]'
writer.spend_update_queue.add_update = AsyncMock()
await writer._update_tag_db(
response_cost=response_cost,
request_tags=request_tags,
prisma_client=None,
)
assert writer.spend_update_queue.add_update.call_count == 0
@@ -2184,3 +2184,107 @@ def test_should_load_db_object_with_supported_db_objects():
)
assert proxy_config._should_load_db_object(object_type="prompts") is True
assert proxy_config._should_load_db_object(object_type="model_cost_map") is True
@pytest.mark.asyncio
async def test_tag_cache_update_called():
"""
Test that update_cache updates tag cache when tags are provided.
"""
from litellm.caching.caching import DualCache
from litellm.proxy.proxy_server import user_api_key_cache
cache = DualCache()
setattr(
litellm.proxy.proxy_server,
"user_api_key_cache",
cache,
)
mock_tag_obj = {
"tag_name": "test-tag",
"spend": 10.0,
}
with patch.object(cache, "async_get_cache", new=AsyncMock(return_value=mock_tag_obj)) as mock_get_cache:
with patch.object(cache, "async_set_cache_pipeline", new=AsyncMock()) as mock_set_cache:
await litellm.proxy.proxy_server.update_cache(
token=None,
user_id=None,
end_user_id=None,
team_id=None,
response_cost=5.0,
parent_otel_span=None,
tags=["test-tag"],
)
await asyncio.sleep(0.1)
mock_get_cache.assert_awaited_once_with(key="tag:test-tag")
mock_set_cache.assert_awaited_once()
call_args = mock_set_cache.call_args
cache_list = call_args.kwargs["cache_list"]
assert len(cache_list) == 1
cache_key, cache_value = cache_list[0]
assert cache_key == "tag:test-tag"
assert cache_value["spend"] == 15.0
@pytest.mark.asyncio
async def test_tag_cache_update_multiple_tags():
"""
Test that multiple tags are updated in cache.
"""
from litellm.caching.caching import DualCache
from litellm.proxy.proxy_server import user_api_key_cache
cache = DualCache()
setattr(
litellm.proxy.proxy_server,
"user_api_key_cache",
cache,
)
mock_tag1_obj = {"tag_name": "tag1", "spend": 10.0}
mock_tag2_obj = {"tag_name": "tag2", "spend": 20.0}
async def mock_get_cache_side_effect(key):
if key == "tag:tag1":
return mock_tag1_obj
elif key == "tag:tag2":
return mock_tag2_obj
return None
with patch.object(cache, "async_get_cache", new=AsyncMock(side_effect=mock_get_cache_side_effect)) as mock_get_cache:
with patch.object(cache, "async_set_cache_pipeline", new=AsyncMock()) as mock_set_cache:
await litellm.proxy.proxy_server.update_cache(
token=None,
user_id=None,
end_user_id=None,
team_id=None,
response_cost=5.0,
parent_otel_span=None,
tags=["tag1", "tag2"],
)
await asyncio.sleep(0.1)
assert mock_get_cache.call_count == 2
mock_set_cache.assert_awaited_once()
call_args = mock_set_cache.call_args
cache_list = call_args.kwargs["cache_list"]
assert len(cache_list) == 2
tag_updates = {cache_key: cache_value for cache_key, cache_value in cache_list}
assert "tag:tag1" in tag_updates
assert "tag:tag2" in tag_updates
assert tag_updates["tag:tag1"]["spend"] == 15.0
assert tag_updates["tag:tag2"]["spend"] == 25.0
@@ -0,0 +1,81 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen, waitFor } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import CreateTagModal from "./CreateTagModal";
describe("CreateTagModal", () => {
const mockOnCancel = vi.fn();
const mockOnSubmit = vi.fn();
const mockAvailableModels = [
{
model_name: "gpt-4",
litellm_params: {
model: "gpt-4",
},
model_info: {
id: "model-123",
},
},
{
model_name: "claude-3",
litellm_params: {
model: "claude-3",
},
model_info: {
id: "model-456",
},
},
];
beforeEach(() => {
vi.clearAllMocks();
});
it("should submit form with correct values when user creates a tag", async () => {
/**
* Tests that filling out the form and clicking submit calls onSubmit with the correct values.
* This is the core functionality of the CreateTagModal component.
*/
const user = userEvent.setup();
render(
<CreateTagModal
visible={true}
onCancel={mockOnCancel}
onSubmit={mockOnSubmit}
availableModels={mockAvailableModels}
/>
);
// Wait for modal to be visible
await waitFor(() => {
expect(screen.getByText("Create New Tag")).toBeInTheDocument();
});
// Fill in the tag name
const tagNameInput = screen.getByLabelText("Tag Name");
await user.type(tagNameInput, "production-tag");
// Fill in the description
const descriptionTextarea = screen.getByLabelText("Description");
await user.type(descriptionTextarea, "Tag for production environment");
// Submit the form
const createButton = screen.getByText("Create Tag");
await user.click(createButton);
// Verify onSubmit was called with correct values
await waitFor(() => {
expect(mockOnSubmit).toHaveBeenCalledWith(
expect.objectContaining({
tag_name: "production-tag",
description: "Tag for production environment",
})
);
});
// Verify onCancel was not called
expect(mockOnCancel).not.toHaveBeenCalled();
});
});
@@ -0,0 +1,153 @@
import React from "react";
import { Button, TextInput, Accordion, AccordionHeader, AccordionBody, Title } from "@tremor/react";
import { Modal, Form, Select as Select2, Tooltip, Input } from "antd";
import { InfoCircleOutlined } from "@ant-design/icons";
import NumericalInput from "../../shared/numerical_input";
import BudgetDurationDropdown from "../../common_components/budget_duration_dropdown";
interface ModelInfo {
model_name: string;
litellm_params: {
model: string;
};
model_info: {
id: string;
};
}
interface CreateTagModalProps {
visible: boolean;
onCancel: () => void;
onSubmit: (values: any) => void;
availableModels: ModelInfo[];
}
const CreateTagModal: React.FC<CreateTagModalProps> = ({
visible,
onCancel,
onSubmit,
availableModels,
}) => {
const [form] = Form.useForm();
const handleFinish = (values: any) => {
onSubmit(values);
form.resetFields();
};
const handleCancel = () => {
form.resetFields();
onCancel();
};
return (
<Modal
title="Create New Tag"
visible={visible}
width={800}
footer={null}
onCancel={handleCancel}
>
<Form
form={form}
onFinish={handleFinish}
labelCol={{ span: 8 }}
wrapperCol={{ span: 16 }}
labelAlign="left"
>
<Form.Item
label="Tag Name"
name="tag_name"
rules={[{ required: true, message: "Please input a tag name" }]}
>
<TextInput />
</Form.Item>
<Form.Item label="Description" name="description">
<Input.TextArea rows={4} />
</Form.Item>
<Form.Item
label={
<span>
Allowed Models{" "}
<Tooltip title="Select which LLMs are allowed to process requests from this tag">
<InfoCircleOutlined style={{ marginLeft: "4px" }} />
</Tooltip>
</span>
}
name="allowed_llms"
>
<Select2 mode="multiple" placeholder="Select LLMs">
{availableModels.map((model) => (
<Select2.Option key={model.model_info.id} value={model.model_info.id}>
<div>
<span>{model.model_name}</span>
<span className="text-gray-400 ml-2">({model.model_info.id})</span>
</div>
</Select2.Option>
))}
</Select2>
</Form.Item>
<Accordion className="mt-4 mb-4">
<AccordionHeader>
<Title className="m-0">Budget & Rate Limits (Optional)</Title>
</AccordionHeader>
<AccordionBody>
<Form.Item
className="mt-4"
label={
<span>
Max Budget (USD){" "}
<Tooltip title="Maximum amount in USD this tag can spend. When reached, requests with this tag will be blocked">
<InfoCircleOutlined style={{ marginLeft: "4px" }} />
</Tooltip>
</span>
}
name="max_budget"
>
<NumericalInput step={0.01} precision={2} width={200} />
</Form.Item>
<Form.Item
className="mt-4"
label={
<span>
Reset Budget{" "}
<Tooltip title="How often the budget should reset. For example, setting 'daily' will reset the budget every 24 hours">
<InfoCircleOutlined style={{ marginLeft: "4px" }} />
</Tooltip>
</span>
}
name="budget_duration"
>
<BudgetDurationDropdown onChange={(value) => form.setFieldValue("budget_duration", value)} />
</Form.Item>
<div className="mt-4 p-3 bg-gray-50 rounded-md border border-gray-200">
<p className="text-sm text-gray-600">
TPM/RPM limits for tags are not currently supported. If you need this feature, please{" "}
<a
href="https://github.com/BerriAI/litellm/issues/new"
target="_blank"
rel="noopener noreferrer"
className="text-blue-600 hover:text-blue-800 underline"
>
create a GitHub issue
</a>
.
</p>
</div>
</AccordionBody>
</Accordion>
<div style={{ textAlign: "right", marginTop: "10px" }}>
<Button type="submit">Create Tag</Button>
</div>
</Form>
</Modal>
);
};
export default CreateTagModal;
@@ -1,14 +1,13 @@
import React, { useState, useEffect } from "react";
import { Icon, Button, Col, Text, Grid, TextInput } from "@tremor/react";
import { Icon, Button, Col, Text, Grid } from "@tremor/react";
import { RefreshIcon } from "@heroicons/react/outline";
import { Modal, Form, Select as Select2, Tooltip, Input } from "antd";
import { InfoCircleOutlined } from "@ant-design/icons";
import TagInfoView from "./tag_info";
import { modelInfoCall } from "../networking";
import { tagCreateCall, tagListCall, tagDeleteCall } from "../networking";
import { Tag } from "./types";
import TagTable from "./TagTable";
import NotificationsManager from "../molecules/notifications_manager";
import CreateTagModal from "./components/CreateTagModal";
interface ModelInfo {
model_name: string;
@@ -34,7 +33,6 @@ const TagManagement: React.FC<TagProps> = ({ accessToken, userID, userRole }) =>
const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false);
const [tagToDelete, setTagToDelete] = useState<string | null>(null);
const [lastRefreshed, setLastRefreshed] = useState("");
const [form] = Form.useForm();
const [availableModels, setAvailableModels] = useState<ModelInfo[]>([]);
const fetchTags = async () => {
@@ -62,10 +60,14 @@ const TagManagement: React.FC<TagProps> = ({ accessToken, userID, userRole }) =>
name: formValues.tag_name,
description: formValues.description,
models: formValues.allowed_llms,
max_budget: formValues.max_budget,
soft_budget: formValues.soft_budget,
tpm_limit: formValues.tpm_limit,
rpm_limit: formValues.rpm_limit,
budget_duration: formValues.budget_duration,
});
NotificationsManager.success("Tag created successfully");
setIsCreateModalVisible(false);
form.resetFields();
fetchTags();
} catch (error) {
console.error("Error creating tag:", error);
@@ -173,63 +175,12 @@ const TagManagement: React.FC<TagProps> = ({ accessToken, userID, userRole }) =>
</Grid>
{/* Create Tag Modal */}
<Modal
title="Create New Tag"
<CreateTagModal
visible={isCreateModalVisible}
width={800}
footer={null}
onCancel={() => {
setIsCreateModalVisible(false);
form.resetFields();
}}
>
<Form
form={form}
onFinish={handleCreate}
labelCol={{ span: 8 }}
wrapperCol={{ span: 16 }}
labelAlign="left"
>
<Form.Item
label="Tag Name"
name="tag_name"
rules={[{ required: true, message: "Please input a tag name" }]}
>
<TextInput />
</Form.Item>
<Form.Item label="Description" name="description">
<Input.TextArea rows={4} />
</Form.Item>
<Form.Item
label={
<span>
Allowed Models{" "}
<Tooltip title="Select which LLMs are allowed to process requests from this tag">
<InfoCircleOutlined style={{ marginLeft: "4px" }} />
</Tooltip>
</span>
}
name="allowed_llms"
>
<Select2 mode="multiple" placeholder="Select LLMs">
{availableModels.map((model) => (
<Select2.Option key={model.model_info.id} value={model.model_info.id}>
<div>
<span>{model.model_name}</span>
<span className="text-gray-400 ml-2">({model.model_info.id})</span>
</div>
</Select2.Option>
))}
</Select2>
</Form.Item>
<div style={{ textAlign: "right", marginTop: "10px" }}>
<Button type="submit">Create Tag</Button>
</div>
</Form>
</Modal>
onCancel={() => setIsCreateModalVisible(false)}
onSubmit={handleCreate}
availableModels={availableModels}
/>
{/* Delete Confirmation Modal */}
{isDeleteModalOpen && (
@@ -1,5 +1,5 @@
import React, { useState, useEffect } from "react";
import { Card, Text, Title, Button, Badge } from "@tremor/react";
import { Card, Text, Title, Button, Badge, Accordion, AccordionHeader, AccordionBody, Title as TremorTitle } from "@tremor/react";
import { Form, Input, Select as Select2, Tooltip } from "antd";
import { InfoCircleOutlined } from "@ant-design/icons";
import { fetchUserModels } from "../organisms/create_key_button";
@@ -7,6 +7,11 @@ import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_
import { tagInfoCall, tagUpdateCall } from "../networking";
import { Tag } from "./types";
import NotificationsManager from "../molecules/notifications_manager";
import NumericalInput from "../shared/numerical_input";
import BudgetDurationDropdown from "../common_components/budget_duration_dropdown";
import { copyToClipboard as utilCopyToClipboard } from "@/utils/dataUtils";
import { CheckIcon, CopyIcon } from "lucide-react";
import { Button as AntdButton } from "antd";
interface TagInfoViewProps {
tagId: string;
@@ -21,6 +26,17 @@ const TagInfoView: React.FC<TagInfoViewProps> = ({ tagId, onClose, accessToken,
const [tagDetails, setTagDetails] = useState<Tag | null>(null);
const [isEditing, setIsEditing] = useState<boolean>(editTag);
const [userModels, setUserModels] = useState<string[]>([]);
const [copiedStates, setCopiedStates] = useState<Record<string, boolean>>({});
const copyToClipboard = async (text: string | null | undefined, key: string) => {
const success = await utilCopyToClipboard(text);
if (success) {
setCopiedStates((prev) => ({ ...prev, [key]: true }));
setTimeout(() => {
setCopiedStates((prev) => ({ ...prev, [key]: false }));
}, 2000);
}
};
const fetchTagDetails = async () => {
if (!accessToken) return;
@@ -34,6 +50,8 @@ const TagInfoView: React.FC<TagInfoViewProps> = ({ tagId, onClose, accessToken,
name: tagData.name,
description: tagData.description,
models: tagData.models,
max_budget: tagData.litellm_budget_table?.max_budget,
budget_duration: tagData.litellm_budget_table?.budget_duration,
});
}
}
@@ -62,6 +80,10 @@ const TagInfoView: React.FC<TagInfoViewProps> = ({ tagId, onClose, accessToken,
name: values.name,
description: values.description,
models: values.models,
max_budget: values.max_budget,
tpm_limit: values.tpm_limit,
rpm_limit: values.rpm_limit,
budget_duration: values.budget_duration,
});
NotificationsManager.success("Tag updated successfully");
setIsEditing(false);
@@ -83,7 +105,23 @@ const TagInfoView: React.FC<TagInfoViewProps> = ({ tagId, onClose, accessToken,
<Button onClick={onClose} className="mb-4">
Back to Tags
</Button>
<Title>Tag Name: {tagDetails.name}</Title>
<div className="flex items-center gap-2">
<Text className="font-medium">Tag Name:</Text>
<span className="font-mono px-2 py-1 bg-gray-100 rounded text-sm border border-gray-200">
{tagDetails.name}
</span>
<AntdButton
type="text"
size="small"
icon={copiedStates["tag-name"] ? <CheckIcon size={12} /> : <CopyIcon size={12} />}
onClick={() => copyToClipboard(tagDetails.name, "tag-name")}
className={`transition-all duration-200 ${
copiedStates["tag-name"]
? "text-green-600 bg-green-50 border-green-200"
: "text-gray-500 hover:text-gray-700 hover:bg-gray-100"
}`}
/>
</div>
<Text className="text-gray-500">{tagDetails.description || "No description"}</Text>
</div>
{is_admin && !isEditing && <Button onClick={() => setIsEditing(true)}>Edit Tag</Button>}
@@ -120,6 +158,56 @@ const TagInfoView: React.FC<TagInfoViewProps> = ({ tagId, onClose, accessToken,
</Select2>
</Form.Item>
<Accordion className="mt-4 mb-4">
<AccordionHeader>
<TremorTitle className="m-0">Budget & Rate Limits</TremorTitle>
</AccordionHeader>
<AccordionBody>
<Form.Item
label={
<span>
Max Budget (USD){" "}
<Tooltip title="Maximum amount in USD this tag can spend">
<InfoCircleOutlined style={{ marginLeft: "4px" }} />
</Tooltip>
</span>
}
name="max_budget"
>
<NumericalInput step={0.01} precision={2} width={200} />
</Form.Item>
<Form.Item
label={
<span>
Reset Budget{" "}
<Tooltip title="How often the budget should reset">
<InfoCircleOutlined style={{ marginLeft: "4px" }} />
</Tooltip>
</span>
}
name="budget_duration"
>
<BudgetDurationDropdown onChange={(value) => form.setFieldValue("budget_duration", value)} />
</Form.Item>
<div className="mt-4 p-3 bg-gray-50 rounded-md border border-gray-200">
<p className="text-sm text-gray-600">
TPM/RPM limits for tags are not currently supported. If you need this feature, please{" "}
<a
href="https://github.com/BerriAI/litellm/issues/new"
target="_blank"
rel="noopener noreferrer"
className="text-blue-600 hover:text-blue-800 underline"
>
create a GitHub issue
</a>
.
</p>
</div>
</AccordionBody>
</Accordion>
<div className="flex justify-end space-x-2">
<Button onClick={() => setIsEditing(false)}>Cancel</Button>
<Button type="submit">Save Changes</Button>
@@ -142,7 +230,7 @@ const TagInfoView: React.FC<TagInfoViewProps> = ({ tagId, onClose, accessToken,
<div>
<Text className="font-medium">Allowed LLMs</Text>
<div className="flex flex-wrap gap-2 mt-2">
{tagDetails.models.length === 0 ? (
{!tagDetails.models || tagDetails.models.length === 0 ? (
<Badge color="red">All Models</Badge>
) : (
tagDetails.models.map((modelId) => (
@@ -163,6 +251,38 @@ const TagInfoView: React.FC<TagInfoViewProps> = ({ tagId, onClose, accessToken,
</div>
</div>
</Card>
{tagDetails.litellm_budget_table && (
<Card>
<Title>Budget & Rate Limits</Title>
<div className="space-y-4 mt-4">
{tagDetails.litellm_budget_table.max_budget !== undefined && tagDetails.litellm_budget_table.max_budget !== null && (
<div>
<Text className="font-medium">Max Budget</Text>
<Text>${tagDetails.litellm_budget_table.max_budget}</Text>
</div>
)}
{tagDetails.litellm_budget_table.budget_duration && (
<div>
<Text className="font-medium">Budget Duration</Text>
<Text>{tagDetails.litellm_budget_table.budget_duration}</Text>
</div>
)}
{tagDetails.litellm_budget_table.tpm_limit !== undefined && tagDetails.litellm_budget_table.tpm_limit !== null && (
<div>
<Text className="font-medium">TPM Limit</Text>
<Text>{tagDetails.litellm_budget_table.tpm_limit.toLocaleString()}</Text>
</div>
)}
{tagDetails.litellm_budget_table.rpm_limit !== undefined && tagDetails.litellm_budget_table.rpm_limit !== null && (
<div>
<Text className="font-medium">RPM Limit</Text>
<Text>{tagDetails.litellm_budget_table.rpm_limit.toLocaleString()}</Text>
</div>
)}
</div>
</Card>
)}
</div>
)}
</div>
@@ -7,6 +7,15 @@ export interface Tag {
updated_at: string;
created_by?: string;
updated_by?: string;
litellm_budget_table?: {
max_budget?: number;
soft_budget?: number;
tpm_limit?: number;
rpm_limit?: number;
max_parallel_requests?: number;
budget_duration?: string;
model_max_budget?: any;
};
}
export interface TagInfoRequest {
@@ -17,12 +26,22 @@ export interface TagNewRequest {
name: string;
description?: string;
models: string[];
max_budget?: number;
soft_budget?: number;
tpm_limit?: number;
rpm_limit?: number;
budget_duration?: string;
}
export interface TagUpdateRequest {
name: string;
description?: string;
models: string[];
max_budget?: number;
soft_budget?: number;
tpm_limit?: number;
rpm_limit?: number;
budget_duration?: string;
}
export interface TagDeleteRequest {