diff --git a/docs/my-website/docs/proxy/tag_budgets.md b/docs/my-website/docs/proxy/tag_budgets.md new file mode 100644 index 000000000..01b82ff8d --- /dev/null +++ b/docs/my-website/docs/proxy/tag_budgets.md @@ -0,0 +1,277 @@ +import Image from '@theme/IdealImage'; +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# Setting Tag Budgets + +Track spend and set budgets for your API requests using tags. Tags allow you to categorize and monitor costs across different cost centers, projects, and departments. + +## Pre-Requisites + +- You must set up a Postgres database (e.g. Supabase, Neon, etc.) + +## What are Tags? + +Tags are labels you can attach to your LLM requests to track and limit spending by category. + +**Common Use Cases:** +- **Cost Center Tracking**: Allocate LLM costs to specific departments or business units (e.g., "engineering", "marketing", "customer-support") +- **Project-based Budgeting**: Set budgets for different projects or initiatives (e.g., "project-alpha", "chatbot-v2") +- **Customer Attribution**: Track spend per customer or client (e.g., "customer-acme", "customer-techcorp") +- **Feature Monitoring**: Monitor costs for specific features (e.g., "feature-chat", "feature-summarization") + +Tags are added to each request in the `metadata` field to track and enforce budget limits. + +## Setting Tag Budgets + +### 1. Create a tag with budget + +Create a tag to represent a cost center, project, or any budget category. Set `max_budget` ($ value allowed) and `budget_duration` (how frequently the budget resets). + +**Example:** Create a tag for your Engineering department with a monthly $500 budget + +#### API + +Create a new tag and set `max_budget` and `budget_duration` + +```shell +curl -X POST 'http://0.0.0.0:4000/tag/new' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{ + "name": "engineering", + "description": "Engineering department cost center", + "max_budget": 500.0, + "budget_duration": "30d" + }' +``` + +**Request Body Parameters:** + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `name` | string | Yes | Unique name for the tag (e.g., cost center name) | +| `description` | string | No | Description of what this tag tracks | +| `models` | list[string] | No | Restrict tag to specific models | +| `max_budget` | float | No | Maximum budget in USD | +| `budget_duration` | string | No | How often budget resets (e.g., "30d", "1d") | +| `soft_budget` | float | No | Soft budget limit for warnings | + +**Response:** + +```json +{ + "name": "engineering", + "description": "Engineering department cost center", + "max_budget": 500.0, + "budget_duration": "30d", + "budget_reset_at": "2025-11-10T00:00:00Z", + "created_at": "2025-10-11T00:00:00Z" +} +``` + +#### LiteLLM Admin UI + +Navigate to the **Tag Management** page and click **Create New Tag**. Fill in the tag details and set your budget: + + + +
+ + +**Possible values for `budget_duration`:** + +| `budget_duration` | When Budget will reset | +| --- | --- | +| `budget_duration="1s"` | every 1 second | +| `budget_duration="1m"` | every 1 minute | +| `budget_duration="1h"` | every 1 hour | +| `budget_duration="1d"` | every 1 day | +| `budget_duration="7d"` | every 1 week | +| `budget_duration="30d"` | every 1 month | + +### 2. Use the tag in your requests + +Add tags to your API requests in the `metadata` field: + +:::info Tags Budgets on API Keys + +Currently, tag budget enforcement is only supported per request. If you'd like to set tags on API keys so all requests automatically inherit the tags budgets, please [create a feature request on GitHub](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=enhancement&projects=&template=feature_request.yml&title=%5BFeat%5D%3A). + +::: + + + + + +```python +import openai + +client = openai.OpenAI( + api_key="sk-1234", # Your LiteLLM proxy key + base_url="http://0.0.0.0:4000" +) + +response = client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Hello"}], + extra_body={ + "metadata": { + "tags": ["engineering"] + } + } +) +``` + + + + + +```shell +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "gpt-4", + "messages": [{"role": "user", "content": "Hello"}], + "metadata": { + "tags": ["engineering"] + } + }' +``` + + + + + +### 3. Test It + +Make requests until the budget is exceeded: + +```shell +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "gpt-4", + "messages": [{"role": "user", "content": "Hello"}], + "metadata": { + "tags": ["engineering"] + } + }' +``` + +**When budget is exceeded, you'll see:** + +```json +{ + "error": { + "message": "Budget has been exceeded! Tag=engineering Current cost: 505.50, Max budget: 500.0", + "type": "budget_exceeded", + "param": null, + "code": "400" + } +} +``` + +## Managing Tags + +### View Tag Information + +Get information about specific tags: + +```shell +curl -X POST 'http://0.0.0.0:4000/tag/info' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{ + "names": ["engineering", "marketing"] + }' +``` + +**Response:** + +```json +{ + "engineering": { + "name": "engineering", + "description": "Engineering department cost center", + "spend": 245.50, + "max_budget": 500.0, + "budget_duration": "30d", + "budget_reset_at": "2025-11-10T00:00:00Z", + "created_at": "2025-10-11T00:00:00Z", + "updated_at": "2025-10-11T12:30:00Z" + }, + "marketing": { + "name": "marketing", + "description": "Marketing department cost center", + "spend": 89.20, + "max_budget": 300.0, + "budget_duration": "30d", + "budget_reset_at": "2025-11-10T00:00:00Z", + "created_at": "2025-10-11T00:00:00Z", + "updated_at": "2025-10-11T12:30:00Z" + } +} +``` + +### Update Tag Budget + +Update an existing tag's budget: + +```shell +curl -X POST 'http://0.0.0.0:4000/tag/update' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{ + "name": "engineering", + "max_budget": 750.0, + "budget_duration": "30d" + }' +``` + +### Delete Tag + +```shell +curl -X POST 'http://0.0.0.0:4000/tag/delete' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{ + "name": "engineering" + }' +``` + +## Multiple Tags per Request + +You can apply multiple tags to a single request to track costs across different dimensions simultaneously. For example, track both the cost center and the specific project: + +```python +response = client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Hello"}], + extra_body={ + "metadata": { + "tags": ["engineering", "project-alpha", "customer-acme"] + } + } +) +``` + +```shell +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "gpt-4", + "messages": [{"role": "user", "content": "Hello"}], + "metadata": { + "tags": ["engineering", "project-alpha", "customer-acme"] + } + }' +``` + +**Budget Enforcement:** If any tag exceeds its budget, the request will be rejected. diff --git a/docs/my-website/img/tag_budget1.png b/docs/my-website/img/tag_budget1.png new file mode 100644 index 000000000..061e406f4 Binary files /dev/null and b/docs/my-website/img/tag_budget1.png differ diff --git a/docs/my-website/img/tag_budget2.png b/docs/my-website/img/tag_budget2.png new file mode 100644 index 000000000..f44fd79dd Binary files /dev/null and b/docs/my-website/img/tag_budget2.png differ diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 00db20b8c..6d5d6d68d 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -189,12 +189,13 @@ const sidebars = { type: "category", label: "Budgets + Rate Limits", items: [ + "proxy/users", + "proxy/team_budgets", + "proxy/tag_budgets", "proxy/customers", "proxy/dynamic_rate_limit", "proxy/rate_limit_tiers", - "proxy/team_budgets", "proxy/temporary_budget_increase", - "proxy/users" ], }, "proxy/caching", diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index 625897c3c..a13af1afc 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -25,6 +25,7 @@ model LiteLLM_BudgetTable { organization LiteLLM_OrganizationTable[] // multiple orgs can have the same budget keys LiteLLM_VerificationToken[] // multiple keys can have the same budget end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget + tags LiteLLM_TagTable[] // multiple tags can have the same budget team_membership LiteLLM_TeamMembership[] // budgets of Users within a Team organization_membership LiteLLM_OrganizationMembership[] // budgets of Users within a Organization } @@ -245,6 +246,20 @@ model LiteLLM_EndUserTable { blocked Boolean @default(false) } +// Track tags with budgets and spend +model LiteLLM_TagTable { + tag_name String @id + description String? + models String[] + model_info Json? // maps model_id to model_name + spend Float @default(0.0) + budget_id String? + litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) + created_at DateTime @default(now()) @map("created_at") + created_by String? + updated_at DateTime @default(now()) @updatedAt @map("updated_at") +} + // store proxy config.yaml model LiteLLM_Config { param_name String @id diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 096bbc029..e2d68df29 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -185,6 +185,7 @@ class Litellm_EntityType(enum.Enum): TEAM = "team" TEAM_MEMBER = "team_member" ORGANIZATION = "organization" + TAG = "tag" # global proxy level entity PROXY = "proxy" @@ -2143,6 +2144,30 @@ class LiteLLM_EndUserTable(LiteLLMPydanticObjectBase): model_config = ConfigDict(protected_namespaces=()) +class LiteLLM_TagTable(LiteLLMPydanticObjectBase): + tag_name: str + description: Optional[str] = None + models: List[str] = [] + model_info: Optional[dict] = None + spend: float = 0.0 + budget_id: Optional[str] = None + litellm_budget_table: Optional[LiteLLM_BudgetTable] = None + created_at: Optional[datetime] = None + created_by: Optional[str] = None + updated_at: Optional[datetime] = None + + @model_validator(mode="before") + @classmethod + def set_model_info(cls, values): + if values.get("spend") is None: + values.update({"spend": 0.0}) + if values.get("models") is None: + values.update({"models": []}) + return values + + model_config = ConfigDict(protected_namespaces=()) + + class LiteLLM_SpendLogs(LiteLLMPydanticObjectBase): request_id: str api_key: str @@ -3419,6 +3444,7 @@ class DBSpendUpdateTransactions(TypedDict): team_list_transactions: Optional[Dict[str, float]] team_member_list_transactions: Optional[Dict[str, float]] org_list_transactions: Optional[Dict[str, float]] + tag_list_transactions: Optional[Dict[str, float]] class SpendUpdateQueueItem(TypedDict, total=False): diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 68708b8fa..d95b7bd03 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -35,6 +35,7 @@ from litellm.proxy._types import ( LiteLLM_ObjectPermissionTable, LiteLLM_OrganizationMembershipTable, LiteLLM_OrganizationTable, + LiteLLM_TagTable, LiteLLM_TeamMembership, LiteLLM_TeamTable, LiteLLM_TeamTableCachedObj, @@ -99,6 +100,8 @@ async def common_checks( 10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks 11. [OPTIONAL] Vector store checks - is the object allowed to access the vector store """ + from litellm.proxy.proxy_server import prisma_client, user_api_key_cache + _model: Optional[Union[str, List[str]]] = get_model_from_request( request_body, route ) @@ -139,6 +142,14 @@ async def common_checks( valid_token=valid_token, ) + await _tag_max_budget_check( + request_body=request_body, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + valid_token=valid_token, + ) + # 4. If user is in budget ## 4.1 check personal budget, if personal key if ( @@ -499,6 +510,115 @@ async def get_end_user_object( return None +@log_db_metrics +async def get_tag_objects_batch( + tag_names: List[str], + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, +) -> Dict[str, LiteLLM_TagTable]: + """ + Batch fetch multiple tag objects from cache and db. + + Optimizes for latency by: + 1. Fetching all cached tags in parallel + 2. Batch fetching uncached tags in one DB query + + Args: + tag_names: List of tag names to fetch + prisma_client: Prisma database client + user_api_key_cache: Cache for storing tag objects + parent_otel_span: Optional OpenTelemetry span for tracing + proxy_logging_obj: Optional proxy logging object + + Returns: + Dictionary mapping tag_name to LiteLLM_TagTable object + """ + if prisma_client is None: + return {} + + if not tag_names: + return {} + + tag_objects = {} + uncached_tags = [] + + # Try to get all tags from cache first + for tag_name in tag_names: + cache_key = f"tag:{tag_name}" + cached_tag = await user_api_key_cache.async_get_cache(key=cache_key) + if cached_tag is not None: + if isinstance(cached_tag, dict): + tag_objects[tag_name] = LiteLLM_TagTable(**cached_tag) + else: + tag_objects[tag_name] = cached_tag + else: + uncached_tags.append(tag_name) + + # Batch fetch uncached tags from DB in one query + if uncached_tags: + try: + db_tags = await prisma_client.db.litellm_tagtable.find_many( + where={"tag_name": {"in": uncached_tags}}, + include={"litellm_budget_table": True}, + ) + + # Cache and add to tag_objects + for db_tag in db_tags: + tag_name = db_tag.tag_name + cache_key = f"tag:{tag_name}" + # Cache with default TTL (same as end_user objects) + await user_api_key_cache.async_set_cache( + key=cache_key, value=db_tag.dict() + ) + tag_objects[tag_name] = LiteLLM_TagTable(**db_tag.dict()) + except Exception as e: + verbose_proxy_logger.debug( + f"Error batch fetching tags from database: {e}" + ) + + return tag_objects + + +@log_db_metrics +async def get_tag_object( + tag_name: Optional[str], + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, +) -> Optional[LiteLLM_TagTable]: + """ + Returns tag object from cache or db. + + Uses default cache TTL (same as end_user objects) to avoid drift. + + Args: + tag_name: Name of the tag to fetch + prisma_client: Prisma database client + user_api_key_cache: Cache for storing tag objects + parent_otel_span: Optional OpenTelemetry span for tracing + proxy_logging_obj: Optional proxy logging object + + Returns: + LiteLLM_TagTable object if found, None otherwise + """ + if prisma_client is None or tag_name is None: + return None + + # Use batch helper for consistency + tag_objects = await get_tag_objects_batch( + tag_names=[tag_name], + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + + return tag_objects.get(tag_name) + + @log_db_metrics async def get_team_membership( user_id: str, @@ -1642,6 +1762,60 @@ async def _team_max_budget_check( ) +async def _tag_max_budget_check( + request_body: dict, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + proxy_logging_obj: ProxyLogging, + valid_token: Optional[UserAPIKeyAuth], +): + """ + Check if any tags in the request are over their max budget. + + Raises: + BudgetExceededError if any tag is over its max budget. + Triggers a budget alert if any tag is over its max budget. + """ + from litellm.proxy.common_utils.http_parsing_utils import ( + get_tags_from_request_body, + ) + + if prisma_client is None: + return + + # Get tags from request metadata + tags = get_tags_from_request_body(request_body=request_body) + if not tags: + return + + # Batch fetch all tags in one go + tag_objects = await get_tag_objects_batch( + tag_names=tags, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + # Check budget for each tag + for tag_name in tags: + tag_object = tag_objects.get(tag_name) + if tag_object is None: + continue + + # Check if tag has budget limits + if ( + tag_object.litellm_budget_table is not None + and tag_object.litellm_budget_table.max_budget is not None + and tag_object.spend is not None + and tag_object.spend > tag_object.litellm_budget_table.max_budget + ): + raise litellm.BudgetExceededError( + current_cost=tag_object.spend, + max_budget=tag_object.litellm_budget_table.max_budget, + message=f"Budget has been exceeded! Tag={tag_name} Current cost: {tag_object.spend}, Max budget: {tag_object.litellm_budget_table.max_budget}", + ) + + def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool: """ Check if a model matches an allowed pattern. diff --git a/litellm/proxy/common_utils/http_parsing_utils.py b/litellm/proxy/common_utils/http_parsing_utils.py index fd84eeda0..33d432e86 100644 --- a/litellm/proxy/common_utils/http_parsing_utils.py +++ b/litellm/proxy/common_utils/http_parsing_utils.py @@ -7,6 +7,9 @@ from fastapi import Request, UploadFile, status from litellm._logging import verbose_proxy_logger from litellm.proxy._types import ProxyException +from litellm.proxy.common_utils.callback_utils import ( + get_metadata_variable_name_from_kwargs, +) from litellm.types.router import Deployment @@ -245,4 +248,23 @@ async def get_request_body(request: Request) -> Dict[str, Any]: raise ValueError( f"Unsupported content type: {request.headers.get('content-type')}" ) - return {} \ No newline at end of file + return {} + + +def get_tags_from_request_body(request_body: dict) -> List[str]: + """ + Extract tags from request body metadata. + + Args: + request_body: The request body dictionary + + Returns: + List of tag names (strings), empty list if no valid tags found + """ + metadata_variable_name = get_metadata_variable_name_from_kwargs(request_body) + metadata = request_body.get(metadata_variable_name, {}) + tags_in_metadata: List[str] = metadata.get("tags", []) + tags_in_request_body: List[str] = request_body.get("tags", []) + combined_tags: List[str] = tags_in_metadata + tags_in_request_body + return [tag for tag in combined_tags if isinstance(tag, str)] + diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py index 3277f19c7..4258ab41d 100644 --- a/litellm/proxy/db/db_spend_update_writer.py +++ b/litellm/proxy/db/db_spend_update_writer.py @@ -17,6 +17,7 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm.caching import DualCache, RedisCache from litellm.constants import DB_SPEND_UPDATE_JOB_NAME +from litellm.litellm_core_utils.safe_json_loads import safe_json_loads from litellm.proxy._types import ( DB_CONNECTION_ERROR_TYPES, BaseDailySpendTransaction, @@ -147,6 +148,13 @@ class DBSpendUpdateWriter: prisma_client=prisma_client, ) ) + asyncio.create_task( + self._update_tag_db( + response_cost=response_cost, + request_tags=payload.get("request_tags"), + prisma_client=prisma_client, + ) + ) if disable_spend_logs is False: await self._insert_spend_log_to_db( @@ -325,6 +333,54 @@ class DBSpendUpdateWriter: ) raise e + async def _update_tag_db( + self, + response_cost: Optional[float], + request_tags: Optional[str], + prisma_client: Optional[PrismaClient], + ): + """ + Update spend for all tags in the request. + + Args: + response_cost: Cost of the request + request_tags: JSON string of tags list e.g. '["prod-tag", "test-tag"]' + prisma_client: Prisma client instance + """ + try: + if request_tags is None or prisma_client is None: + return + + # Parse tags from JSON string + tags = [] + if isinstance(request_tags, str): + tags = safe_json_loads(request_tags, default=[]) + if not tags: + verbose_proxy_logger.debug( + f"Failed to parse request_tags JSON: {request_tags}" + ) + return + elif isinstance(request_tags, list): + tags = request_tags + else: + return + + # Update spend for each tag + for tag_name in tags: + if tag_name and isinstance(tag_name, str): + await self.spend_update_queue.add_update( + update=SpendUpdateQueueItem( + entity_type=Litellm_EntityType.TAG, + entity_id=tag_name, + response_cost=response_cost, + ) + ) + except Exception as e: + verbose_proxy_logger.debug( + f"Update Tag DB failed to execute - {str(e)}\n{traceback.format_exc()}" + ) + raise e + async def _insert_spend_log_to_db( self, payload: Union[dict, SpendLogsPayload], @@ -775,6 +831,75 @@ class DBSpendUpdateWriter: e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) + ### UPDATE TAG TABLE ### + tag_list_transactions = db_spend_update_transactions["tag_list_transactions"] + await DBSpendUpdateWriter._update_entity_spend_in_db( + entity_name="Tag", + transactions=tag_list_transactions, + table_accessor="litellm_tagtable", + where_field="tag_name", + n_retry_times=n_retry_times, + prisma_client=prisma_client, + proxy_logging_obj=proxy_logging_obj, + ) + + @staticmethod + async def _update_entity_spend_in_db( + entity_name: str, + transactions: Optional[Dict[str, float]], + table_accessor: Any, + where_field: str, + n_retry_times: int, + prisma_client: PrismaClient, + proxy_logging_obj: ProxyLogging, + ): + """ + Helper function to update spend for any entity type (team, org, tag, etc). + + Args: + entity_name: Name of entity for logging (e.g., "Team", "Org", "Tag") + transactions: Dictionary of {entity_id: response_cost} + table_accessor: Prisma table accessor (e.g., prisma_client.db.litellm_teamtable) + where_field: Field name for where clause (e.g., "team_id", "organization_id", "tag_name") + n_retry_times: Number of retries on failure + prisma_client: Prisma client instance + proxy_logging_obj: Proxy logging object + """ + from litellm.proxy.utils import _raise_failed_update_spend_exception + + verbose_proxy_logger.debug( + f"{entity_name} Spend transactions: {transactions}" + ) + if transactions is not None and len(transactions.keys()) > 0: + for i in range(n_retry_times + 1): + start_time = time.time() + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for entity_id, response_cost in transactions.items(): + verbose_proxy_logger.debug( + f"Updating spend for {entity_name} {where_field}={entity_id} by {response_cost}" + ) + getattr(batcher, table_accessor).update_many( + where={where_field: entity_id}, + data={"spend": {"increment": response_cost}}, + ) + break + except DB_CONNECTION_ERROR_TYPES as e: + if i >= n_retry_times: + _raise_failed_update_spend_exception( + e=e, + start_time=start_time, + proxy_logging_obj=proxy_logging_obj, + ) + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + # fmt: off @overload diff --git a/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py b/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py index b08a8517e..91d0bee1d 100644 --- a/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py +++ b/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py @@ -380,6 +380,7 @@ class RedisUpdateBuffer: team_list_transactions={}, team_member_list_transactions={}, org_list_transactions={}, + tag_list_transactions={}, ) # Define the transaction fields to process @@ -390,6 +391,7 @@ class RedisUpdateBuffer: "team_list_transactions", "team_member_list_transactions", "org_list_transactions", + "tag_list_transactions", ] # Loop through each transaction and combine the values diff --git a/litellm/proxy/db/db_transaction_queue/spend_update_queue.py b/litellm/proxy/db/db_transaction_queue/spend_update_queue.py index 98a9e5088..9b0449bb9 100644 --- a/litellm/proxy/db/db_transaction_queue/spend_update_queue.py +++ b/litellm/proxy/db/db_transaction_queue/spend_update_queue.py @@ -135,6 +135,7 @@ class SpendUpdateQueue(BaseUpdateQueue): team_list_transactions={}, team_member_list_transactions={}, org_list_transactions={}, + tag_list_transactions={}, ) # Map entity types to their corresponding transaction dictionary keys @@ -145,6 +146,7 @@ class SpendUpdateQueue(BaseUpdateQueue): Litellm_EntityType.TEAM: "team_list_transactions", Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions", Litellm_EntityType.ORGANIZATION: "org_list_transactions", + Litellm_EntityType.TAG: "tag_list_transactions", } for update in updates: @@ -192,6 +194,10 @@ class SpendUpdateQueue(BaseUpdateQueue): transactions_dict = db_spend_update_transactions[ "org_list_transactions" ] + elif dict_key == "tag_list_transactions": + transactions_dict = db_spend_update_transactions[ + "tag_list_transactions" + ] else: continue diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py index 08e6540f7..17b91db69 100644 --- a/litellm/proxy/hooks/proxy_track_cost_callback.py +++ b/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -1,7 +1,7 @@ import asyncio import traceback from datetime import datetime -from typing import Any, Optional, Union, cast +from typing import Any, List, Optional, Union, cast import litellm from litellm._logging import verbose_proxy_logger @@ -131,6 +131,11 @@ class _ProxyDBLogger(CustomLogger): if sl_object is not None else kwargs.get("response_cost", None) ) + tags: Optional[List[str]] = ( + sl_object.get("request_tags", None) + if sl_object is not None + else None + ) if response_cost is not None: user_api_key = metadata.get("user_api_key", None) @@ -172,6 +177,7 @@ class _ProxyDBLogger(CustomLogger): response_cost=response_cost, team_id=team_id, parent_otel_span=parent_otel_span, + tags=tags, ) ) diff --git a/litellm/proxy/management_endpoints/organization_endpoints.py b/litellm/proxy/management_endpoints/organization_endpoints.py index 2b52b6fc5..a8cbf6209 100644 --- a/litellm/proxy/management_endpoints/organization_endpoints.py +++ b/litellm/proxy/management_endpoints/organization_endpoints.py @@ -11,12 +11,12 @@ Endpoints for /organization operations #### ORGANIZATION MANAGEMENT #### -from litellm._uuid import uuid from typing import List, Optional, Tuple from fastapi import APIRouter, Depends, HTTPException, Request, status from litellm._logging import verbose_proxy_logger +from litellm._uuid import uuid from litellm.proxy._types import * from litellm.proxy.auth.auth_checks import can_user_call_model from litellm.proxy.auth.user_api_key_auth import user_api_key_auth @@ -29,6 +29,7 @@ from litellm.proxy.management_helpers.object_permission_utils import ( ) from litellm.proxy.management_helpers.utils import ( get_new_internal_user_defaults, + handle_budget_for_entity, management_endpoint_wrapper, ) from litellm.proxy.utils import PrismaClient @@ -168,30 +169,17 @@ async def new_organization( except Exception: pass - if data.budget_id is None: - """ - Every organization needs a budget attached. - - If none provided, create one based on provided values - """ - budget_params = LiteLLM_BudgetTable.model_fields.keys() - - # Only include Budget Params when creating an entry in litellm_budgettable - _json_data = data.json(exclude_none=True) - _budget_data = {k: v for k, v in _json_data.items() if k in budget_params} - budget_row = LiteLLM_BudgetTable(**_budget_data) - - new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True)) - - _budget = await prisma_client.db.litellm_budgettable.create( - data={ - **new_budget, # type: ignore - "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - } - ) # type: ignore - - data.budget_id = _budget.budget_id + # Handle budget creation/assignment using common helper + budget_id = await handle_budget_for_entity( + data=data, + existing_budget_id=None, + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ) + + if budget_id is not None: + data.budget_id = budget_id ## Handle Object Permission - MCP, Vector Stores etc. object_permission_id = await _set_object_permission( @@ -322,20 +310,22 @@ async def update_organization( existing_organization_row=existing_organization_row, ) - # Handle budget updates if budget fields are provided - budget_fields = {k: v for k, v in data.model_dump().items() - if k in LiteLLM_BudgetTable.model_fields.keys() and v is not None} + from litellm.proxy.proxy_server import litellm_proxy_admin_name + + # Handle budget updates using common helper + budget_id = await handle_budget_for_entity( + data=data, + existing_budget_id=existing_organization_row.budget_id, + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ) - if budget_fields and existing_organization_row.budget_id: - await update_budget( - budget_obj=BudgetNewRequest( - budget_id=existing_organization_row.budget_id, - **budget_fields - ), - user_api_key_dict=user_api_key_dict, - ) + # Update budget_id if it changed + if budget_id != existing_organization_row.budget_id: + updated_organization_row["budget_id"] = budget_id - # Remove budget fields from organization update data + # Remove budget fields from organization update data (they're handled via budget table) for field in LiteLLM_BudgetTable.model_fields.keys(): updated_organization_row.pop(field, None) diff --git a/litellm/proxy/management_endpoints/tag_management_endpoints.py b/litellm/proxy/management_endpoints/tag_management_endpoints.py index 0bd7b3eb8..1366c2ef4 100644 --- a/litellm/proxy/management_endpoints/tag_management_endpoints.py +++ b/litellm/proxy/management_endpoints/tag_management_endpoints.py @@ -11,7 +11,6 @@ All /tag management endpoints """ import asyncio -import datetime import json from typing import TYPE_CHECKING, Dict, List, Optional @@ -25,6 +24,7 @@ from litellm.proxy.management_endpoints.common_daily_activity import ( SpendAnalyticsPaginatedResponse, get_daily_activity, ) +from litellm.proxy.management_helpers.utils import handle_budget_for_entity from litellm.types.tag_management import ( LiteLLM_DailyTagSpendTable, TagConfig, @@ -53,69 +53,6 @@ async def _get_model_names(prisma_client, model_ids: list) -> Dict[str, str]: return {} -async def _get_tags_config(prisma_client) -> Dict[str, TagConfig]: - """Helper function to get tags config from db""" - try: - tags_config = await prisma_client.db.litellm_config.find_unique( - where={"param_name": "tags_config"} - ) - if tags_config is None: - return {} - # Convert from JSON if needed - if isinstance(tags_config.param_value, str): - config_dict = json.loads(tags_config.param_value) - else: - config_dict = tags_config.param_value or {} - - # For each tag, get the model names - for tag_name, tag_config in config_dict.items(): - if isinstance(tag_config, dict) and tag_config.get("models"): - model_info = await _get_model_names(prisma_client, tag_config["models"]) - tag_config["model_info"] = model_info - - return config_dict - except Exception: - return {} - - -async def _save_tags_config(prisma_client, tags_config: Dict[str, TagConfig]): - """Helper function to save tags config to db""" - try: - verbose_proxy_logger.debug(f"Saving tags config: {tags_config}") - # Convert TagConfig objects to dictionaries - tags_config_dict = {} - for name, tag in tags_config.items(): - if isinstance(tag, TagConfig): - tag_dict = tag.model_dump() - # Remove model_info before saving as it will be dynamically generated - if "model_info" in tag_dict: - del tag_dict["model_info"] - tags_config_dict[name] = tag_dict - else: - # If it's already a dict, remove model_info - tag_copy = tag.copy() - if "model_info" in tag_copy: - del tag_copy["model_info"] - tags_config_dict[name] = tag_copy - - json_tags_config = json.dumps(tags_config_dict, default=str) - verbose_proxy_logger.debug(f"JSON tags config: {json_tags_config}") - await prisma_client.db.litellm_config.upsert( - where={"param_name": "tags_config"}, - data={ - "create": { - "param_name": "tags_config", - "param_value": json_tags_config, - }, - "update": {"param_value": json_tags_config}, - }, - ) - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Error saving tags config: {str(e)}" - ) - - async def get_deployments_by_model( model: str, llm_router: "Router" ) -> List["Deployment"]: @@ -159,9 +96,23 @@ async def new_tag( - name: str - The name of the tag - description: Optional[str] - Description of what this tag represents - models: List[str] - List of either 'model_id' or 'model_name' allowed for this tag + - budget_id: Optional[str] - The id for a budget (tpm/rpm/max budget) for the tag + + ### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ### + - max_budget: Optional[float] - Max budget for tag + - tpm_limit: Optional[int] - Max tpm limit for tag + - rpm_limit: Optional[int] - Max rpm limit for tag + - max_parallel_requests: Optional[int] - Max parallel requests for tag + - soft_budget: Optional[float] - Get a slack alert when this soft budget is reached + - model_max_budget: Optional[dict] - Max budget for a specific model + - budget_duration: Optional[str] - Frequency of resetting tag budget """ from litellm.proxy._types import CommonProxyErrors - from litellm.proxy.proxy_server import llm_router, prisma_client + from litellm.proxy.proxy_server import ( + litellm_proxy_admin_name, + llm_router, + prisma_client, + ) if prisma_client is None: raise HTTPException( @@ -172,29 +123,38 @@ async def new_tag( status_code=500, detail=CommonProxyErrors.no_llm_router.value ) try: - # Get existing tags config - tags_config = await _get_tags_config(prisma_client) - # Check if tag already exists - if tag.name in tags_config: + existing_tag = await prisma_client.db.litellm_tagtable.find_unique( + where={"tag_name": tag.name} + ) + if existing_tag is not None: raise HTTPException( status_code=400, detail=f"Tag {tag.name} already exists" ) - # Add new tag - tags_config[tag.name] = TagConfig( - name=tag.name, - description=tag.description, - models=tag.models, - created_at=str(datetime.datetime.now()), - updated_at=str(datetime.datetime.now()), - created_by=user_api_key_dict.user_id, + # Handle budget creation/assignment using common helper + budget_id = await handle_budget_for_entity( + data=tag, + existing_budget_id=None, + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + litellm_proxy_admin_name=litellm_proxy_admin_name, ) - # Save updated config - await _save_tags_config( - prisma_client=prisma_client, - tags_config=tags_config, + # Get model names for model_info + model_info = await _get_model_names(prisma_client, tag.models or []) + + # Create new tag in database + new_tag_record = await prisma_client.db.litellm_tagtable.create( + data={ + "tag_name": tag.name, + "description": tag.description, + "models": tag.models or [], + "model_info": json.dumps(model_info), + "spend": 0.0, + "budget_id": budget_id, + "created_by": user_api_key_dict.user_id, + } ) # Update models with new tag @@ -213,13 +173,20 @@ async def new_tag( ) await asyncio.gather(*tasks) - # Get model names for response - model_info = await _get_model_names(prisma_client, tag.models or []) - tags_config[tag.name].model_info = model_info + # Build response + tag_config = TagConfig( + name=new_tag_record.tag_name, + description=new_tag_record.description, + models=new_tag_record.models, + model_info=model_info, + created_at=new_tag_record.created_at.isoformat(), + updated_at=new_tag_record.updated_at.isoformat(), + created_by=new_tag_record.created_by, + ) return { "message": f"Tag {tag.name} created successfully", - "tag": tags_config[tag.name], + "tag": tag_config, } except Exception as e: verbose_proxy_logger.exception(f"Error creating tag: {str(e)}") @@ -264,6 +231,16 @@ async def update_tag( - name: str - The name of the tag to update - description: Optional[str] - Updated description - models: List[str] - Updated list of allowed LLM models + - budget_id: Optional[str] - The id for a budget to associate with the tag + + ### BUDGET UPDATE PARAMS ### + - max_budget: Optional[float] - Max budget for tag + - tpm_limit: Optional[int] - Max tpm limit for tag + - rpm_limit: Optional[int] - Max rpm limit for tag + - max_parallel_requests: Optional[int] - Max parallel requests for tag + - soft_budget: Optional[float] - Get a slack alert when this soft budget is reached + - model_max_budget: Optional[dict] - Max budget for a specific model + - budget_duration: Optional[str] - Frequency of resetting tag budget """ from litellm.proxy.proxy_server import prisma_client @@ -271,35 +248,58 @@ async def update_tag( raise HTTPException(status_code=500, detail="Database not connected") try: - # Get existing tags config - tags_config = await _get_tags_config(prisma_client) - # Check if tag exists - if tag.name not in tags_config: + existing_tag = await prisma_client.db.litellm_tagtable.find_unique( + where={"tag_name": tag.name} + ) + if existing_tag is None: raise HTTPException(status_code=404, detail=f"Tag {tag.name} not found") - # Update tag - tag_config_dict = dict(tags_config[tag.name]) - tag_config_dict.update( - { - "description": tag.description, - "models": tag.models, - "updated_at": str(datetime.datetime.now()), - "updated_by": user_api_key_dict.user_id, - } + from litellm.proxy.proxy_server import litellm_proxy_admin_name + + # Handle budget updates using common helper + budget_id = await handle_budget_for_entity( + data=tag, + existing_budget_id=existing_tag.budget_id, + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + litellm_proxy_admin_name=litellm_proxy_admin_name, ) - tags_config[tag.name] = TagConfig(**tag_config_dict) - # Save updated config - await _save_tags_config(prisma_client, tags_config) - - # Get model names for response + # Get model names for model_info model_info = await _get_model_names(prisma_client, tag.models or []) - tags_config[tag.name].model_info = model_info + + # Prepare update data + update_data = { + "description": tag.description, + "models": tag.models or [], + "model_info": json.dumps(model_info), + } + + # Add budget_id if it changed + if budget_id != existing_tag.budget_id: + update_data["budget_id"] = budget_id + + # Update tag in database + updated_tag_record = await prisma_client.db.litellm_tagtable.update( + where={"tag_name": tag.name}, + data=update_data, + ) + + # Build response + tag_config = TagConfig( + name=updated_tag_record.tag_name, + description=updated_tag_record.description, + models=updated_tag_record.models, + model_info=model_info, + created_at=updated_tag_record.created_at.isoformat(), + updated_at=updated_tag_record.updated_at.isoformat(), + created_by=updated_tag_record.created_by, + ) return { "message": f"Tag {tag.name} updated successfully", - "tag": tags_config[tag.name], + "tag": tag_config, } except Exception as e: verbose_proxy_logger.exception(f"Error updating tag: {str(e)}") @@ -327,18 +327,47 @@ async def info_tag( raise HTTPException(status_code=500, detail="Database not connected") try: - tags_config = await _get_tags_config(prisma_client) - - # Filter tags based on requested names - requested_tags = {name: tags_config.get(name) for name in data.names} + # Query tags from database with budget info + tag_records = await prisma_client.db.litellm_tagtable.find_many( + where={"tag_name": {"in": data.names}}, + include={"litellm_budget_table": True}, + ) # Check if any requested tags don't exist - missing_tags = [name for name in data.names if name not in tags_config] + found_tag_names = {tag.tag_name for tag in tag_records} + missing_tags = [name for name in data.names if name not in found_tag_names] if missing_tags: raise HTTPException( status_code=404, detail=f"Tags not found: {missing_tags}" ) + # Build response + requested_tags = {} + for tag_record in tag_records: + # Parse model_info from JSON + model_info = {} + if tag_record.model_info: + if isinstance(tag_record.model_info, str): + model_info = json.loads(tag_record.model_info) + else: + model_info = tag_record.model_info + + tag_dict = { + "name": tag_record.tag_name, + "description": tag_record.description, + "models": tag_record.models, + "model_info": model_info, + "created_at": tag_record.created_at.isoformat(), + "updated_at": tag_record.updated_at.isoformat(), + "created_by": tag_record.created_by, + } + + # Add budget info if available + if hasattr(tag_record, "litellm_budget_table") and tag_record.litellm_budget_table: + tag_dict["litellm_budget_table"] = tag_record.litellm_budget_table + + requested_tags[tag_record.tag_name] = tag_dict + return requested_tags except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -348,13 +377,12 @@ async def info_tag( "/tag/list", tags=["tag management"], dependencies=[Depends(user_api_key_auth)], - response_model=List[TagConfig], ) async def list_tags( user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ - List all available tags. + List all available tags with their budget information. """ from litellm.proxy.proxy_server import prisma_client @@ -363,8 +391,37 @@ async def list_tags( try: ## QUERY STORED TAGS ## - tags_config = await _get_tags_config(prisma_client) - list_of_tags = list(tags_config.values()) + tag_records = await prisma_client.db.litellm_tagtable.find_many( + include={"litellm_budget_table": True} + ) + + stored_tag_names = set() + list_of_tags = [] + for tag_record in tag_records: + stored_tag_names.add(tag_record.tag_name) + # Parse model_info from JSON + model_info = {} + if tag_record.model_info: + if isinstance(tag_record.model_info, str): + model_info = json.loads(tag_record.model_info) + else: + model_info = tag_record.model_info + + tag_dict = { + "name": tag_record.tag_name, + "description": tag_record.description, + "models": tag_record.models, + "model_info": model_info, + "created_at": tag_record.created_at.isoformat(), + "updated_at": tag_record.updated_at.isoformat(), + "created_by": tag_record.created_by, + } + + # Add budget info if available + if hasattr(tag_record, "litellm_budget_table") and tag_record.litellm_budget_table: + tag_dict["litellm_budget_table"] = tag_record.litellm_budget_table + + list_of_tags.append(tag_dict) ## QUERY DYNAMIC TAGS ## dynamic_tags = await prisma_client.db.litellm_dailytagspend.find_many( @@ -377,15 +434,15 @@ async def list_tags( ] dynamic_tag_config = [ - TagConfig( - name=tag.tag, - description="This is just a spend tag that was passed dynamically in a request. It does not control any LLM models.", - models=None, - created_at=tag.created_at.isoformat(), - updated_at=tag.updated_at.isoformat(), - ) + { + "name": tag.tag, + "description": "This is just a spend tag that was passed dynamically in a request. It does not control any LLM models.", + "models": None, + "created_at": tag.created_at.isoformat(), + "updated_at": tag.updated_at.isoformat(), + } for tag in dynamic_tags_list - if tag.tag not in tags_config + if tag.tag not in stored_tag_names ] return list_of_tags + dynamic_tag_config @@ -414,18 +471,15 @@ async def delete_tag( raise HTTPException(status_code=500, detail="Database not connected") try: - # Get existing tags config - tags_config = await _get_tags_config(prisma_client) - # Check if tag exists - if data.name not in tags_config: + existing_tag = await prisma_client.db.litellm_tagtable.find_unique( + where={"tag_name": data.name} + ) + if existing_tag is None: raise HTTPException(status_code=404, detail=f"Tag {data.name} not found") - # Delete tag - del tags_config[data.name] - - # Save updated config - await _save_tags_config(prisma_client, tags_config) + # Delete tag from database + await prisma_client.db.litellm_tagtable.delete(where={"tag_name": data.name}) return {"message": f"Tag {data.name} deleted successfully"} except Exception as e: diff --git a/litellm/proxy/management_helpers/utils.py b/litellm/proxy/management_helpers/utils.py index 3eab44441..67a1ea659 100644 --- a/litellm/proxy/management_helpers/utils.py +++ b/litellm/proxy/management_helpers/utils.py @@ -10,10 +10,12 @@ import litellm from litellm._logging import verbose_logger from litellm._uuid import uuid from litellm.proxy._types import ( # key request types; user request types; team request types; customer request types + BudgetNewRequest, DeleteCustomerRequest, DeleteTeamRequest, DeleteUserRequest, KeyRequest, + LiteLLM_BudgetTable, LiteLLM_TeamMembership, LiteLLM_UserTable, ManagementEndpointLoggingPayload, @@ -53,6 +55,89 @@ def get_new_internal_user_defaults( return non_null_dict +async def handle_budget_for_entity( + data, + existing_budget_id: Optional[str], + user_api_key_dict: UserAPIKeyAuth, + prisma_client: PrismaClient, + litellm_proxy_admin_name: str, +) -> Optional[str]: + """ + Common helper to handle budget creation/updates for entities (organizations, tags, etc). + + This function: + 1. Creates a new budget if budget_id is None but budget fields are provided + 2. Updates an existing budget if budget fields are provided and budget_id exists + 3. Returns the budget_id to use (existing or newly created) + + Args: + data: The request object (e.g., TagNewRequest, NewOrganizationRequest, etc.) containing budget fields + existing_budget_id: The existing budget_id if updating an entity, None if creating new + user_api_key_dict: User authentication info + prisma_client: Database client + litellm_proxy_admin_name: Admin name for audit trail + + Returns: + Optional[str]: The budget_id to use, or None if no budget was created/updated + """ + from litellm.proxy.management_endpoints.budget_management_endpoints import ( + update_budget, + ) + + # Get all budget field names + budget_params = LiteLLM_BudgetTable.model_fields.keys() + + # Extract budget fields from data + _json_data = data.model_dump(exclude_none=True) if hasattr(data, "model_dump") else data + _budget_data = {k: v for k, v in _json_data.items() if k in budget_params} + + # Check if budget_id is explicitly provided in the data + data_budget_id = getattr(data, "budget_id", None) + + # Case 1: Creating new entity - no existing budget_id + if existing_budget_id is None: + if data_budget_id is not None: + # Use the provided budget_id + return data_budget_id + elif _budget_data: + # Create a new budget with the provided fields + budget_row = LiteLLM_BudgetTable(**_budget_data) + new_budget_data = prisma_client.jsonify_object( + budget_row.model_dump(exclude_none=True) + ) + + _budget = await prisma_client.db.litellm_budgettable.create( + data={ + **new_budget_data, # type: ignore + "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + } + ) # type: ignore + + return _budget.budget_id + else: + # No budget fields provided, no budget to create + return None + + # Case 2: Updating existing entity - has existing budget_id + else: + # If budget fields are provided, update the existing budget + if _budget_data: + await update_budget( + budget_obj=BudgetNewRequest( + budget_id=existing_budget_id, **_budget_data + ), + user_api_key_dict=user_api_key_dict, + ) + + # If a different budget_id is explicitly provided, use that instead + if data_budget_id is not None and data_budget_id != existing_budget_id: + return data_budget_id + + # Otherwise, keep using the existing budget_id + return existing_budget_id + + async def add_new_member( new_member: Member, max_budget_in_team: Optional[float], diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index a4e759343..be0f7d6be 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -33,9 +33,9 @@ from litellm.constants import ( AIOHTTP_TTL_DNS_CACHE, BASE_MCP_ROUTE, DEFAULT_MAX_RECURSE_DEPTH, - DEFAULT_SLACK_ALERTING_THRESHOLD, - DEFAULT_SHARED_HEALTH_CHECK_TTL, DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL, + DEFAULT_SHARED_HEALTH_CHECK_TTL, + DEFAULT_SLACK_ALERTING_THRESHOLD, LITELLM_EMBEDDING_PROVIDERS_SUPPORTING_INPUT_ARRAY_OF_TOKENS, LITELLM_SETTINGS_SAFE_DB_OVERRIDES, ) @@ -1157,6 +1157,7 @@ async def update_cache( # noqa: PLR0915 team_id: Optional[str], response_cost: Optional[float], parent_otel_span: Optional[Span], # type: ignore + tags: Optional[List[str]] = None, ): """ Use this to update the cache with new user spend. @@ -1377,6 +1378,50 @@ async def update_cache( # noqa: PLR0915 f"An error occurred updating end user cache: {str(e)}" ) + ### UPDATE TAG SPEND ### + async def _update_tag_cache(): + """ + Update the tag cache with the new spend. + """ + if tags is None or response_cost is None: + return + + try: + for tag_name in tags: + if not tag_name or not isinstance(tag_name, str): + continue + + cache_key = f"tag:{tag_name}" + # Fetch the existing tag object from cache + existing_tag_obj = await user_api_key_cache.async_get_cache(key=cache_key) + if existing_tag_obj is None: + # do nothing if tag not in api key cache + continue + + verbose_proxy_logger.debug( + f"_update_tag_cache: existing spend for tag={tag_name}: {existing_tag_obj}; response_cost: {response_cost}" + ) + + if isinstance(existing_tag_obj, dict): + existing_spend = existing_tag_obj.get("spend", 0) or 0 + else: + existing_spend = getattr(existing_tag_obj, "spend", 0) or 0 + + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost + + # Update the spend column for the given tag + if isinstance(existing_tag_obj, dict): + existing_tag_obj["spend"] = new_spend + values_to_update_in_cache.append((cache_key, existing_tag_obj)) + else: + existing_tag_obj.spend = new_spend + values_to_update_in_cache.append((cache_key, existing_tag_obj)) + except Exception as e: + verbose_proxy_logger.exception( + f"An error occurred updating tag cache: {str(e)}" + ) + if token is not None and response_cost is not None: await _update_key_cache(token=token, response_cost=response_cost) @@ -1389,6 +1434,9 @@ async def update_cache( # noqa: PLR0915 if team_id is not None: await _update_team_cache() + if tags is not None: + await _update_tag_cache() + asyncio.create_task( user_api_key_cache.async_set_cache_pipeline( cache_list=values_to_update_in_cache, @@ -1431,7 +1479,9 @@ async def _run_background_health_check(): # Initialize shared health check manager if Redis is available and feature is enabled shared_health_manager = None if use_shared_health_check and redis_usage_cache is not None: - from litellm.proxy.health_check_utils.shared_health_check_manager import SharedHealthCheckManager + from litellm.proxy.health_check_utils.shared_health_check_manager import ( + SharedHealthCheckManager, + ) shared_health_manager = SharedHealthCheckManager( redis_cache=redis_usage_cache, health_check_ttl=DEFAULT_SHARED_HEALTH_CHECK_TTL, @@ -1451,10 +1501,13 @@ async def _run_background_health_check(): ] # Use shared health check if available, otherwise fall back to direct health check + # Convert health_check_details to bool for perform_shared_health_check (defaults to True if None) + details_bool = health_check_details if health_check_details is not None else True + if shared_health_manager is not None: try: healthy_endpoints, unhealthy_endpoints = await shared_health_manager.perform_shared_health_check( - model_list=_llm_model_list, details=health_check_details + model_list=_llm_model_list, details=details_bool ) except Exception as e: verbose_proxy_logger.error( diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 625897c3c..a13af1afc 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -25,6 +25,7 @@ model LiteLLM_BudgetTable { organization LiteLLM_OrganizationTable[] // multiple orgs can have the same budget keys LiteLLM_VerificationToken[] // multiple keys can have the same budget end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget + tags LiteLLM_TagTable[] // multiple tags can have the same budget team_membership LiteLLM_TeamMembership[] // budgets of Users within a Team organization_membership LiteLLM_OrganizationMembership[] // budgets of Users within a Organization } @@ -245,6 +246,20 @@ model LiteLLM_EndUserTable { blocked Boolean @default(false) } +// Track tags with budgets and spend +model LiteLLM_TagTable { + tag_name String @id + description String? + models String[] + model_info Json? // maps model_id to model_name + spend Float @default(0.0) + budget_id String? + litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) + created_at DateTime @default(now()) @map("created_at") + created_by String? + updated_at DateTime @default(now()) @updatedAt @map("updated_at") +} + // store proxy config.yaml model LiteLLM_Config { param_name String @id diff --git a/litellm/types/tag_management.py b/litellm/types/tag_management.py index a3615b658..a9b58ace0 100644 --- a/litellm/types/tag_management.py +++ b/litellm/types/tag_management.py @@ -18,11 +18,27 @@ class TagConfig(TagBase): class TagNewRequest(TagBase): - pass + budget_id: Optional[str] = None + # Budget fields - if budget_id is None, create a new budget with these params + max_budget: Optional[float] = None + soft_budget: Optional[float] = None + max_parallel_requests: Optional[int] = None + tpm_limit: Optional[int] = None + rpm_limit: Optional[int] = None + model_max_budget: Optional[Dict] = None + budget_duration: Optional[str] = None class TagUpdateRequest(TagBase): - pass + budget_id: Optional[str] = None + # Budget fields - if provided, will update the budget + max_budget: Optional[float] = None + soft_budget: Optional[float] = None + max_parallel_requests: Optional[int] = None + tpm_limit: Optional[int] = None + rpm_limit: Optional[int] = None + model_max_budget: Optional[Dict] = None + budget_duration: Optional[str] = None class TagDeleteRequest(BaseModel): diff --git a/schema.prisma b/schema.prisma index 625897c3c..a13af1afc 100644 --- a/schema.prisma +++ b/schema.prisma @@ -25,6 +25,7 @@ model LiteLLM_BudgetTable { organization LiteLLM_OrganizationTable[] // multiple orgs can have the same budget keys LiteLLM_VerificationToken[] // multiple keys can have the same budget end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget + tags LiteLLM_TagTable[] // multiple tags can have the same budget team_membership LiteLLM_TeamMembership[] // budgets of Users within a Team organization_membership LiteLLM_OrganizationMembership[] // budgets of Users within a Organization } @@ -245,6 +246,20 @@ model LiteLLM_EndUserTable { blocked Boolean @default(false) } +// Track tags with budgets and spend +model LiteLLM_TagTable { + tag_name String @id + description String? + models String[] + model_info Json? // maps model_id to model_name + spend Float @default(0.0) + budget_id String? + litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) + created_at DateTime @default(now()) @map("created_at") + created_by String? + updated_at DateTime @default(now()) @updatedAt @map("updated_at") +} + // store proxy config.yaml model LiteLLM_Config { param_name String @id diff --git a/tests/test_litellm/proxy/auth/test_auth_checks.py b/tests/test_litellm/proxy/auth/test_auth_checks.py index 9a50986a1..7d4a406c9 100644 --- a/tests/test_litellm/proxy/auth/test_auth_checks.py +++ b/tests/test_litellm/proxy/auth/test_auth_checks.py @@ -26,9 +26,9 @@ from litellm.proxy._types import ( from litellm.proxy.auth.auth_checks import ( ExperimentalUIJWTToken, _can_object_call_vector_stores, + _get_team_db_check, get_user_object, vector_store_access_check, - _get_team_db_check, ) from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper from litellm.utils import get_utc_datetime @@ -567,3 +567,141 @@ def test_can_object_call_model_no_access_to_alias_or_underlying(): assert exc_info.value.type == ProxyErrorTypes.key_model_access_denied assert "key not allowed to access model" in str(exc_info.value.message) assert "my-fake-gpt" in str(exc_info.value.message) + + +# Tag Budget Enforcement Tests + + +@pytest.mark.asyncio +async def test_get_tag_objects_batch(): + """ + Test batch fetching of tags validates: + - Cached tags are fetched from cache (no DB call for them) + - Uncached tags are fetched in ONE batch DB query + - After fetching, uncached tags are cached + """ + from litellm.proxy._types import LiteLLM_TagTable + from litellm.proxy.auth.auth_checks import get_tag_objects_batch + + mock_prisma = MagicMock() + mock_cache = MagicMock() + mock_proxy_logging = MagicMock() + + # Simulate 5 tags: 2 cached, 3 uncached + tag_names = ["cached-1", "uncached-1", "cached-2", "uncached-2", "uncached-3"] + + # Mock cached tags + cached_tag_1 = { + "tag_name": "cached-1", + "spend": 10.0, + "models": [], + "litellm_budget_table": None, + } + cached_tag_2 = { + "tag_name": "cached-2", + "spend": 20.0, + "models": [], + "litellm_budget_table": None, + } + + # Mock DB response for uncached tags + uncached_tag_1 = MagicMock() + uncached_tag_1.tag_name = "uncached-1" + uncached_tag_1.spend = 30.0 + uncached_tag_1.models = [] + uncached_tag_1.litellm_budget_table = None + uncached_tag_1.dict = MagicMock( + return_value={ + "tag_name": "uncached-1", + "spend": 30.0, + "models": [], + "litellm_budget_table": None, + } + ) + + uncached_tag_2 = MagicMock() + uncached_tag_2.tag_name = "uncached-2" + uncached_tag_2.spend = 40.0 + uncached_tag_2.models = [] + uncached_tag_2.litellm_budget_table = None + uncached_tag_2.dict = MagicMock( + return_value={ + "tag_name": "uncached-2", + "spend": 40.0, + "models": [], + "litellm_budget_table": None, + } + ) + + uncached_tag_3 = MagicMock() + uncached_tag_3.tag_name = "uncached-3" + uncached_tag_3.spend = 50.0 + uncached_tag_3.models = [] + uncached_tag_3.litellm_budget_table = None + uncached_tag_3.dict = MagicMock( + return_value={ + "tag_name": "uncached-3", + "spend": 50.0, + "models": [], + "litellm_budget_table": None, + } + ) + + # Mock cache behavior - return cached tags, None for uncached + async def mock_get_cache(key): + if key == "tag:cached-1": + return cached_tag_1 + elif key == "tag:cached-2": + return cached_tag_2 + else: + return None + + mock_cache.async_get_cache = AsyncMock(side_effect=mock_get_cache) + mock_cache.async_set_cache = AsyncMock() + + # Mock DB to return all uncached tags in ONE query + mock_prisma.db.litellm_tagtable.find_many = AsyncMock( + return_value=[uncached_tag_1, uncached_tag_2, uncached_tag_3] + ) + + # Call batch fetch + tag_objects = await get_tag_objects_batch( + tag_names=tag_names, + prisma_client=mock_prisma, + user_api_key_cache=mock_cache, + proxy_logging_obj=mock_proxy_logging, + ) + + # Verify results + assert len(tag_objects) == 5 + assert "cached-1" in tag_objects + assert "cached-2" in tag_objects + assert "uncached-1" in tag_objects + assert "uncached-2" in tag_objects + assert "uncached-3" in tag_objects + + # Verify cached tags have correct values + assert tag_objects["cached-1"].spend == 10.0 + assert tag_objects["cached-2"].spend == 20.0 + + # Verify uncached tags have correct values + assert tag_objects["uncached-1"].spend == 30.0 + assert tag_objects["uncached-2"].spend == 40.0 + assert tag_objects["uncached-3"].spend == 50.0 + + # Verify DB was called ONCE with all 3 uncached tags + mock_prisma.db.litellm_tagtable.find_many.assert_called_once() + call_args = mock_prisma.db.litellm_tagtable.find_many.call_args + assert call_args.kwargs["where"]["tag_name"]["in"] == [ + "uncached-1", + "uncached-2", + "uncached-3", + ] + + # Verify uncached tags were cached after fetching + assert mock_cache.async_set_cache.call_count == 3 + cache_calls = mock_cache.async_set_cache.call_args_list + cached_keys = [call.kwargs["key"] for call in cache_calls] + assert "tag:uncached-1" in cached_keys + assert "tag:uncached-2" in cached_keys + assert "tag:uncached-3" in cached_keys diff --git a/tests/test_litellm/proxy/common_utils/test_http_parsing_utils.py b/tests/test_litellm/proxy/common_utils/test_http_parsing_utils.py index 721bcfc37..cffed0211 100644 --- a/tests/test_litellm/proxy/common_utils/test_http_parsing_utils.py +++ b/tests/test_litellm/proxy/common_utils/test_http_parsing_utils.py @@ -14,13 +14,17 @@ sys.path.insert( import litellm +from litellm.proxy._types import ProxyException from litellm.proxy.common_utils.http_parsing_utils import ( _read_request_body, + _safe_get_request_headers, _safe_get_request_parsed_body, + _safe_get_request_query_params, _safe_set_request_parsed_body, get_form_data, + get_request_body, + get_tags_from_request_body, ) -from litellm.proxy._types import ProxyException @pytest.mark.asyncio @@ -285,3 +289,96 @@ async def test_get_form_data(): # Note: In a real MultiDict, both values would be present # But in our mock dictionary the second value overwrites the first assert "segment" in result["timestamp_granularities"] + + +def test_get_tags_from_request_body_with_metadata_tags(): + """ + Test that tags are correctly extracted from request body metadata. + """ + request_body = { + "model": "gpt-4", + "metadata": { + "tags": ["tag1", "tag2", "tag3"] + } + } + + result = get_tags_from_request_body(request_body=request_body) + + assert result == ["tag1", "tag2", "tag3"] + + +def test_get_tags_from_request_body_with_litellm_metadata_tags(): + """ + Test that tags are correctly extracted from request body when using litellm_metadata. + """ + request_body = { + "model": "gpt-4", + "litellm_metadata": { + "tags": ["tag1", "tag2", "tag3"] + } + } + + result = get_tags_from_request_body(request_body=request_body) + + assert result == ["tag1", "tag2", "tag3"] + + +def test_get_tags_from_request_body_with_root_tags(): + """ + Test that tags are correctly extracted from root level of request body. + """ + request_body = { + "model": "gpt-4", + "tags": ["tag1", "tag2"] + } + + result = get_tags_from_request_body(request_body=request_body) + + assert result == ["tag1", "tag2"] + + +def test_get_tags_from_request_body_with_combined_tags(): + """ + Test that tags from both metadata and root level are combined. + """ + request_body = { + "model": "gpt-4", + "metadata": { + "tags": ["tag1", "tag2"] + }, + "tags": ["tag3", "tag4"] + } + + result = get_tags_from_request_body(request_body=request_body) + + assert result == ["tag1", "tag2", "tag3", "tag4"] + + +def test_get_tags_from_request_body_filters_non_strings(): + """ + Test that non-string values in tags list are filtered out. + """ + request_body = { + "model": "gpt-4", + "metadata": { + "tags": ["tag1", 123, "tag2", None, "tag3", {"nested": "dict"}] + } + } + + result = get_tags_from_request_body(request_body=request_body) + + assert result == ["tag1", "tag2", "tag3"] + + +def test_get_tags_from_request_body_no_tags(): + """ + Test that empty list is returned when no tags are present. + """ + request_body = { + "model": "gpt-4", + "metadata": {} + } + + result = get_tags_from_request_body(request_body=request_body) + + assert result == [] diff --git a/tests/test_litellm/proxy/db/test_db_spend_update_writer.py b/tests/test_litellm/proxy/db/test_db_spend_update_writer.py index dfa075cbf..435b67e33 100644 --- a/tests/test_litellm/proxy/db/test_db_spend_update_writer.py +++ b/tests/test_litellm/proxy/db/test_db_spend_update_writer.py @@ -135,3 +135,127 @@ async def test_update_daily_spend_with_null_entity_id(): assert create_data["api_requests"] == 1 assert create_data["successful_requests"] == 1 assert create_data["failed_requests"] == 0 + + +# Tag Spend Tracking Tests + + +@pytest.mark.asyncio +async def test_update_tag_db_with_valid_tags(): + """ + Test that _update_tag_db correctly processes valid tags and adds them to the spend update queue. + """ + from litellm.proxy._types import Litellm_EntityType, SpendUpdateQueueItem + + writer = DBSpendUpdateWriter() + mock_prisma = MagicMock() + response_cost = 0.05 + request_tags = '["prod-tag", "test-tag"]' + + writer.spend_update_queue.add_update = AsyncMock() + + await writer._update_tag_db( + response_cost=response_cost, + request_tags=request_tags, + prisma_client=mock_prisma, + ) + + assert writer.spend_update_queue.add_update.call_count == 2 + + first_call_args = writer.spend_update_queue.add_update.call_args_list[0][1] + assert first_call_args["update"]["entity_type"] == Litellm_EntityType.TAG + assert first_call_args["update"]["entity_id"] == "prod-tag" + assert first_call_args["update"]["response_cost"] == response_cost + + second_call_args = writer.spend_update_queue.add_update.call_args_list[1][1] + assert second_call_args["update"]["entity_type"] == Litellm_EntityType.TAG + assert second_call_args["update"]["entity_id"] == "test-tag" + assert second_call_args["update"]["response_cost"] == response_cost + + +@pytest.mark.asyncio +async def test_update_tag_db_with_list_input(): + """ + Test that _update_tag_db correctly handles tags passed as a list instead of JSON string. + """ + writer = DBSpendUpdateWriter() + mock_prisma = MagicMock() + response_cost = 0.1 + request_tags = ["tag1", "tag2", "tag3"] + + writer.spend_update_queue.add_update = AsyncMock() + + await writer._update_tag_db( + response_cost=response_cost, + request_tags=request_tags, + prisma_client=mock_prisma, + ) + + assert writer.spend_update_queue.add_update.call_count == 3 + + +@pytest.mark.asyncio +async def test_update_tag_db_with_no_tags(): + """ + Test that _update_tag_db handles None and empty tags gracefully. + """ + writer = DBSpendUpdateWriter() + mock_prisma = MagicMock() + response_cost = 0.05 + + writer.spend_update_queue.add_update = AsyncMock() + + await writer._update_tag_db( + response_cost=response_cost, + request_tags=None, + prisma_client=mock_prisma, + ) + assert writer.spend_update_queue.add_update.call_count == 0 + + await writer._update_tag_db( + response_cost=response_cost, + request_tags=[], + prisma_client=mock_prisma, + ) + assert writer.spend_update_queue.add_update.call_count == 0 + + +@pytest.mark.asyncio +async def test_update_tag_db_with_invalid_json(): + """ + Test that _update_tag_db handles invalid JSON gracefully. + """ + writer = DBSpendUpdateWriter() + mock_prisma = MagicMock() + response_cost = 0.05 + request_tags = '{"invalid": json}' + + writer.spend_update_queue.add_update = AsyncMock() + + await writer._update_tag_db( + response_cost=response_cost, + request_tags=request_tags, + prisma_client=mock_prisma, + ) + + assert writer.spend_update_queue.add_update.call_count == 0 + + +@pytest.mark.asyncio +async def test_update_tag_db_without_prisma_client(): + """ + Test that _update_tag_db returns early when prisma_client is None. + """ + writer = DBSpendUpdateWriter() + response_cost = 0.05 + request_tags = '["tag1"]' + + writer.spend_update_queue.add_update = AsyncMock() + + await writer._update_tag_db( + response_cost=response_cost, + request_tags=request_tags, + prisma_client=None, + ) + + assert writer.spend_update_queue.add_update.call_count == 0 diff --git a/tests/test_litellm/proxy/test_proxy_server.py b/tests/test_litellm/proxy/test_proxy_server.py index 21930d116..5af7ac5f3 100644 --- a/tests/test_litellm/proxy/test_proxy_server.py +++ b/tests/test_litellm/proxy/test_proxy_server.py @@ -2184,3 +2184,107 @@ def test_should_load_db_object_with_supported_db_objects(): ) assert proxy_config._should_load_db_object(object_type="prompts") is True assert proxy_config._should_load_db_object(object_type="model_cost_map") is True + + +@pytest.mark.asyncio +async def test_tag_cache_update_called(): + """ + Test that update_cache updates tag cache when tags are provided. + """ + from litellm.caching.caching import DualCache + from litellm.proxy.proxy_server import user_api_key_cache + + cache = DualCache() + + setattr( + litellm.proxy.proxy_server, + "user_api_key_cache", + cache, + ) + + mock_tag_obj = { + "tag_name": "test-tag", + "spend": 10.0, + } + + with patch.object(cache, "async_get_cache", new=AsyncMock(return_value=mock_tag_obj)) as mock_get_cache: + with patch.object(cache, "async_set_cache_pipeline", new=AsyncMock()) as mock_set_cache: + await litellm.proxy.proxy_server.update_cache( + token=None, + user_id=None, + end_user_id=None, + team_id=None, + response_cost=5.0, + parent_otel_span=None, + tags=["test-tag"], + ) + + await asyncio.sleep(0.1) + + mock_get_cache.assert_awaited_once_with(key="tag:test-tag") + mock_set_cache.assert_awaited_once() + + call_args = mock_set_cache.call_args + cache_list = call_args.kwargs["cache_list"] + + assert len(cache_list) == 1 + cache_key, cache_value = cache_list[0] + assert cache_key == "tag:test-tag" + assert cache_value["spend"] == 15.0 + + +@pytest.mark.asyncio +async def test_tag_cache_update_multiple_tags(): + """ + Test that multiple tags are updated in cache. + """ + from litellm.caching.caching import DualCache + from litellm.proxy.proxy_server import user_api_key_cache + + cache = DualCache() + + setattr( + litellm.proxy.proxy_server, + "user_api_key_cache", + cache, + ) + + mock_tag1_obj = {"tag_name": "tag1", "spend": 10.0} + mock_tag2_obj = {"tag_name": "tag2", "spend": 20.0} + + async def mock_get_cache_side_effect(key): + if key == "tag:tag1": + return mock_tag1_obj + elif key == "tag:tag2": + return mock_tag2_obj + return None + + with patch.object(cache, "async_get_cache", new=AsyncMock(side_effect=mock_get_cache_side_effect)) as mock_get_cache: + with patch.object(cache, "async_set_cache_pipeline", new=AsyncMock()) as mock_set_cache: + await litellm.proxy.proxy_server.update_cache( + token=None, + user_id=None, + end_user_id=None, + team_id=None, + response_cost=5.0, + parent_otel_span=None, + tags=["tag1", "tag2"], + ) + + await asyncio.sleep(0.1) + + assert mock_get_cache.call_count == 2 + mock_set_cache.assert_awaited_once() + + call_args = mock_set_cache.call_args + cache_list = call_args.kwargs["cache_list"] + + assert len(cache_list) == 2 + + tag_updates = {cache_key: cache_value for cache_key, cache_value in cache_list} + assert "tag:tag1" in tag_updates + assert "tag:tag2" in tag_updates + assert tag_updates["tag:tag1"]["spend"] == 15.0 + assert tag_updates["tag:tag2"]["spend"] == 25.0 + + diff --git a/ui/litellm-dashboard/src/components/tag_management/components/CreateTagModal.test.tsx b/ui/litellm-dashboard/src/components/tag_management/components/CreateTagModal.test.tsx new file mode 100644 index 000000000..a6100fa8f --- /dev/null +++ b/ui/litellm-dashboard/src/components/tag_management/components/CreateTagModal.test.tsx @@ -0,0 +1,81 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { render, screen, waitFor } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import CreateTagModal from "./CreateTagModal"; + +describe("CreateTagModal", () => { + const mockOnCancel = vi.fn(); + const mockOnSubmit = vi.fn(); + const mockAvailableModels = [ + { + model_name: "gpt-4", + litellm_params: { + model: "gpt-4", + }, + model_info: { + id: "model-123", + }, + }, + { + model_name: "claude-3", + litellm_params: { + model: "claude-3", + }, + model_info: { + id: "model-456", + }, + }, + ]; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("should submit form with correct values when user creates a tag", async () => { + /** + * Tests that filling out the form and clicking submit calls onSubmit with the correct values. + * This is the core functionality of the CreateTagModal component. + */ + const user = userEvent.setup(); + + render( + + ); + + // Wait for modal to be visible + await waitFor(() => { + expect(screen.getByText("Create New Tag")).toBeInTheDocument(); + }); + + // Fill in the tag name + const tagNameInput = screen.getByLabelText("Tag Name"); + await user.type(tagNameInput, "production-tag"); + + // Fill in the description + const descriptionTextarea = screen.getByLabelText("Description"); + await user.type(descriptionTextarea, "Tag for production environment"); + + // Submit the form + const createButton = screen.getByText("Create Tag"); + await user.click(createButton); + + // Verify onSubmit was called with correct values + await waitFor(() => { + expect(mockOnSubmit).toHaveBeenCalledWith( + expect.objectContaining({ + tag_name: "production-tag", + description: "Tag for production environment", + }) + ); + }); + + // Verify onCancel was not called + expect(mockOnCancel).not.toHaveBeenCalled(); + }); +}); + diff --git a/ui/litellm-dashboard/src/components/tag_management/components/CreateTagModal.tsx b/ui/litellm-dashboard/src/components/tag_management/components/CreateTagModal.tsx new file mode 100644 index 000000000..4d1909abd --- /dev/null +++ b/ui/litellm-dashboard/src/components/tag_management/components/CreateTagModal.tsx @@ -0,0 +1,153 @@ +import React from "react"; +import { Button, TextInput, Accordion, AccordionHeader, AccordionBody, Title } from "@tremor/react"; +import { Modal, Form, Select as Select2, Tooltip, Input } from "antd"; +import { InfoCircleOutlined } from "@ant-design/icons"; +import NumericalInput from "../../shared/numerical_input"; +import BudgetDurationDropdown from "../../common_components/budget_duration_dropdown"; + +interface ModelInfo { + model_name: string; + litellm_params: { + model: string; + }; + model_info: { + id: string; + }; +} + +interface CreateTagModalProps { + visible: boolean; + onCancel: () => void; + onSubmit: (values: any) => void; + availableModels: ModelInfo[]; +} + +const CreateTagModal: React.FC = ({ + visible, + onCancel, + onSubmit, + availableModels, +}) => { + const [form] = Form.useForm(); + + const handleFinish = (values: any) => { + onSubmit(values); + form.resetFields(); + }; + + const handleCancel = () => { + form.resetFields(); + onCancel(); + }; + + return ( + +
+ + + + + + + + + + Allowed Models{" "} + + + + + } + name="allowed_llms" + > + + {availableModels.map((model) => ( + +
+ {model.model_name} + ({model.model_info.id}) +
+
+ ))} +
+
+ + + + Budget & Rate Limits (Optional) + + + + Max Budget (USD){" "} + + + + + } + name="max_budget" + > + + + + Reset Budget{" "} + + + + + } + name="budget_duration" + > + form.setFieldValue("budget_duration", value)} /> + + +
+

+ TPM/RPM limits for tags are not currently supported. If you need this feature, please{" "} + + create a GitHub issue + + . +

+
+
+
+ +
+ +
+
+
+ ); +}; + +export default CreateTagModal; + diff --git a/ui/litellm-dashboard/src/components/tag_management/index.tsx b/ui/litellm-dashboard/src/components/tag_management/index.tsx index 1b95ef7d9..c8c4eab6f 100644 --- a/ui/litellm-dashboard/src/components/tag_management/index.tsx +++ b/ui/litellm-dashboard/src/components/tag_management/index.tsx @@ -1,14 +1,13 @@ import React, { useState, useEffect } from "react"; -import { Icon, Button, Col, Text, Grid, TextInput } from "@tremor/react"; +import { Icon, Button, Col, Text, Grid } from "@tremor/react"; import { RefreshIcon } from "@heroicons/react/outline"; -import { Modal, Form, Select as Select2, Tooltip, Input } from "antd"; -import { InfoCircleOutlined } from "@ant-design/icons"; import TagInfoView from "./tag_info"; import { modelInfoCall } from "../networking"; import { tagCreateCall, tagListCall, tagDeleteCall } from "../networking"; import { Tag } from "./types"; import TagTable from "./TagTable"; import NotificationsManager from "../molecules/notifications_manager"; +import CreateTagModal from "./components/CreateTagModal"; interface ModelInfo { model_name: string; @@ -34,7 +33,6 @@ const TagManagement: React.FC = ({ accessToken, userID, userRole }) => const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false); const [tagToDelete, setTagToDelete] = useState(null); const [lastRefreshed, setLastRefreshed] = useState(""); - const [form] = Form.useForm(); const [availableModels, setAvailableModels] = useState([]); const fetchTags = async () => { @@ -62,10 +60,14 @@ const TagManagement: React.FC = ({ accessToken, userID, userRole }) => name: formValues.tag_name, description: formValues.description, models: formValues.allowed_llms, + max_budget: formValues.max_budget, + soft_budget: formValues.soft_budget, + tpm_limit: formValues.tpm_limit, + rpm_limit: formValues.rpm_limit, + budget_duration: formValues.budget_duration, }); NotificationsManager.success("Tag created successfully"); setIsCreateModalVisible(false); - form.resetFields(); fetchTags(); } catch (error) { console.error("Error creating tag:", error); @@ -173,63 +175,12 @@ const TagManagement: React.FC = ({ accessToken, userID, userRole }) => {/* Create Tag Modal */} - { - setIsCreateModalVisible(false); - form.resetFields(); - }} - > -
- - - - - - - - - - Allowed Models{" "} - - - - - } - name="allowed_llms" - > - - {availableModels.map((model) => ( - -
- {model.model_name} - ({model.model_info.id}) -
-
- ))} -
-
- -
- -
-
-
+ onCancel={() => setIsCreateModalVisible(false)} + onSubmit={handleCreate} + availableModels={availableModels} + /> {/* Delete Confirmation Modal */} {isDeleteModalOpen && ( diff --git a/ui/litellm-dashboard/src/components/tag_management/tag_info.tsx b/ui/litellm-dashboard/src/components/tag_management/tag_info.tsx index 627e21176..60cde134e 100644 --- a/ui/litellm-dashboard/src/components/tag_management/tag_info.tsx +++ b/ui/litellm-dashboard/src/components/tag_management/tag_info.tsx @@ -1,5 +1,5 @@ import React, { useState, useEffect } from "react"; -import { Card, Text, Title, Button, Badge } from "@tremor/react"; +import { Card, Text, Title, Button, Badge, Accordion, AccordionHeader, AccordionBody, Title as TremorTitle } from "@tremor/react"; import { Form, Input, Select as Select2, Tooltip } from "antd"; import { InfoCircleOutlined } from "@ant-design/icons"; import { fetchUserModels } from "../organisms/create_key_button"; @@ -7,6 +7,11 @@ import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_ import { tagInfoCall, tagUpdateCall } from "../networking"; import { Tag } from "./types"; import NotificationsManager from "../molecules/notifications_manager"; +import NumericalInput from "../shared/numerical_input"; +import BudgetDurationDropdown from "../common_components/budget_duration_dropdown"; +import { copyToClipboard as utilCopyToClipboard } from "@/utils/dataUtils"; +import { CheckIcon, CopyIcon } from "lucide-react"; +import { Button as AntdButton } from "antd"; interface TagInfoViewProps { tagId: string; @@ -21,6 +26,17 @@ const TagInfoView: React.FC = ({ tagId, onClose, accessToken, const [tagDetails, setTagDetails] = useState(null); const [isEditing, setIsEditing] = useState(editTag); const [userModels, setUserModels] = useState([]); + const [copiedStates, setCopiedStates] = useState>({}); + + const copyToClipboard = async (text: string | null | undefined, key: string) => { + const success = await utilCopyToClipboard(text); + if (success) { + setCopiedStates((prev) => ({ ...prev, [key]: true })); + setTimeout(() => { + setCopiedStates((prev) => ({ ...prev, [key]: false })); + }, 2000); + } + }; const fetchTagDetails = async () => { if (!accessToken) return; @@ -34,6 +50,8 @@ const TagInfoView: React.FC = ({ tagId, onClose, accessToken, name: tagData.name, description: tagData.description, models: tagData.models, + max_budget: tagData.litellm_budget_table?.max_budget, + budget_duration: tagData.litellm_budget_table?.budget_duration, }); } } @@ -62,6 +80,10 @@ const TagInfoView: React.FC = ({ tagId, onClose, accessToken, name: values.name, description: values.description, models: values.models, + max_budget: values.max_budget, + tpm_limit: values.tpm_limit, + rpm_limit: values.rpm_limit, + budget_duration: values.budget_duration, }); NotificationsManager.success("Tag updated successfully"); setIsEditing(false); @@ -83,7 +105,23 @@ const TagInfoView: React.FC = ({ tagId, onClose, accessToken, - Tag Name: {tagDetails.name} +
+ Tag Name: + + {tagDetails.name} + + : } + onClick={() => copyToClipboard(tagDetails.name, "tag-name")} + className={`transition-all duration-200 ${ + copiedStates["tag-name"] + ? "text-green-600 bg-green-50 border-green-200" + : "text-gray-500 hover:text-gray-700 hover:bg-gray-100" + }`} + /> +
{tagDetails.description || "No description"} {is_admin && !isEditing && } @@ -120,6 +158,56 @@ const TagInfoView: React.FC = ({ tagId, onClose, accessToken, + + + Budget & Rate Limits + + + + Max Budget (USD){" "} + + + + + } + name="max_budget" + > + + + + + Reset Budget{" "} + + + + + } + name="budget_duration" + > + form.setFieldValue("budget_duration", value)} /> + + +
+

+ TPM/RPM limits for tags are not currently supported. If you need this feature, please{" "} + + create a GitHub issue + + . +

+
+
+
+
@@ -142,7 +230,7 @@ const TagInfoView: React.FC = ({ tagId, onClose, accessToken,
Allowed LLMs
- {tagDetails.models.length === 0 ? ( + {!tagDetails.models || tagDetails.models.length === 0 ? ( All Models ) : ( tagDetails.models.map((modelId) => ( @@ -163,6 +251,38 @@ const TagInfoView: React.FC = ({ tagId, onClose, accessToken,
+ + {tagDetails.litellm_budget_table && ( + + Budget & Rate Limits +
+ {tagDetails.litellm_budget_table.max_budget !== undefined && tagDetails.litellm_budget_table.max_budget !== null && ( +
+ Max Budget + ${tagDetails.litellm_budget_table.max_budget} +
+ )} + {tagDetails.litellm_budget_table.budget_duration && ( +
+ Budget Duration + {tagDetails.litellm_budget_table.budget_duration} +
+ )} + {tagDetails.litellm_budget_table.tpm_limit !== undefined && tagDetails.litellm_budget_table.tpm_limit !== null && ( +
+ TPM Limit + {tagDetails.litellm_budget_table.tpm_limit.toLocaleString()} +
+ )} + {tagDetails.litellm_budget_table.rpm_limit !== undefined && tagDetails.litellm_budget_table.rpm_limit !== null && ( +
+ RPM Limit + {tagDetails.litellm_budget_table.rpm_limit.toLocaleString()} +
+ )} +
+
+ )}
)} diff --git a/ui/litellm-dashboard/src/components/tag_management/types.tsx b/ui/litellm-dashboard/src/components/tag_management/types.tsx index 254be890e..3cf17545f 100644 --- a/ui/litellm-dashboard/src/components/tag_management/types.tsx +++ b/ui/litellm-dashboard/src/components/tag_management/types.tsx @@ -7,6 +7,15 @@ export interface Tag { updated_at: string; created_by?: string; updated_by?: string; + litellm_budget_table?: { + max_budget?: number; + soft_budget?: number; + tpm_limit?: number; + rpm_limit?: number; + max_parallel_requests?: number; + budget_duration?: string; + model_max_budget?: any; + }; } export interface TagInfoRequest { @@ -17,12 +26,22 @@ export interface TagNewRequest { name: string; description?: string; models: string[]; + max_budget?: number; + soft_budget?: number; + tpm_limit?: number; + rpm_limit?: number; + budget_duration?: string; } export interface TagUpdateRequest { name: string; description?: string; models: string[]; + max_budget?: number; + soft_budget?: number; + tpm_limit?: number; + rpm_limit?: number; + budget_duration?: string; } export interface TagDeleteRequest {