mirror of
https://github.com/onyx-dot-app/litellm.git
synced 2026-07-01 20:44:04 -04:00
(feat) Team level model-specific tpm/rpm limits + working key-level validation of tpm/rpm limit when assigned to team (#15513)
* fix(support-model-specific-tpm/rpm-limits): Allows setting rate limits by tpm/rpm for models by team
* fix(key_management_endpoints.py): enforce guaranteed throughput with key-level model tpm/rpm limits, when team-level tpm/rpm limits are set
* test: add unit testing
* fix: fix minor linting errors
* fix: refactor
This commit is contained in:
@@ -1284,6 +1284,8 @@ class NewTeamRequest(TeamBase):
|
||||
prompts: Optional[List[str]] = None
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
|
||||
allowed_passthrough_routes: Optional[list] = None
|
||||
model_rpm_limit: Optional[Dict[str, int]] = None
|
||||
model_tpm_limit: Optional[Dict[str, int]] = None
|
||||
team_member_budget: Optional[float] = (
|
||||
None # allow user to set a budget for all team members
|
||||
)
|
||||
@@ -1340,6 +1342,8 @@ class UpdateTeamRequest(LiteLLMPydanticObjectBase):
|
||||
team_member_tpm_limit: Optional[int] = None
|
||||
team_member_key_duration: Optional[str] = None
|
||||
allowed_passthrough_routes: Optional[list] = None
|
||||
model_rpm_limit: Optional[Dict[str, int]] = None
|
||||
model_tpm_limit: Optional[Dict[str, int]] = None
|
||||
|
||||
|
||||
class ResetTeamBudgetRequest(LiteLLMPydanticObjectBase):
|
||||
|
||||
@@ -453,6 +453,22 @@ def get_key_model_tpm_limit(
|
||||
return None
|
||||
|
||||
|
||||
def get_team_model_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Look up the per-model RPM limits stored on the key's team.

    Args:
        user_api_key_dict: Auth object for the calling key; only its
            `team_metadata` attribute is read.

    Returns:
        The value stored under "model_rpm_limit" in the team metadata, or
        None when the team metadata is unset/empty or has no such entry.
    """
    team_metadata = user_api_key_dict.team_metadata
    if not team_metadata:
        return None
    return team_metadata.get("model_rpm_limit")
|
||||
|
||||
|
||||
def get_team_model_tpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Look up the per-model TPM limits stored on the key's team.

    Args:
        user_api_key_dict: Auth object for the calling key; only its
            `team_metadata` attribute is read.

    Returns:
        The value stored under "model_tpm_limit" in the team metadata, or
        None when the team metadata is unset/empty or has no such entry.
    """
    team_metadata = user_api_key_dict.team_metadata
    if not team_metadata:
        return None
    return team_metadata.get("model_tpm_limit")
|
||||
|
||||
|
||||
def is_pass_through_provider_route(route: str) -> bool:
|
||||
PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES = [
|
||||
"vertex-ai",
|
||||
|
||||
@@ -103,6 +103,7 @@ return results
|
||||
REDIS_CLUSTER_SLOTS = 16384
|
||||
REDIS_NODE_HASHTAG_NAME = "all_keys"
|
||||
|
||||
|
||||
class RateLimitDescriptorRateLimitObject(TypedDict, total=False):
|
||||
requests_per_unit: Optional[int]
|
||||
tokens_per_unit: Optional[int]
|
||||
@@ -157,15 +158,17 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
||||
def _is_redis_cluster(self) -> bool:
    """
    Check if the dual cache is using Redis cluster.

    Returns:
        bool: True if using Redis cluster, False otherwise.
    """
    from litellm.caching.redis_cluster_cache import RedisClusterCache

    # Read the attribute chain once and test it once; the previous form
    # repeated the same isinstance() check on the same object.
    redis_cache = self.internal_usage_cache.dual_cache.redis_cache
    return redis_cache is not None and isinstance(redis_cache, RedisClusterCache)
|
||||
|
||||
async def in_memory_cache_sliding_window(
|
||||
@@ -310,7 +313,7 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
||||
)
|
||||
|
||||
return RateLimitResponse(overall_code=overall_code, statuses=statuses)
|
||||
|
||||
|
||||
def keyslot_for_redis_cluster(self, key: str) -> int:
|
||||
"""
|
||||
Compute the Redis Cluster slot for a given key.
|
||||
@@ -325,34 +328,34 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
||||
Returns:
|
||||
int: The slot number (0-16383).
|
||||
|
||||
|
||||
|
||||
"""
|
||||
# Handle hash tags: use substring between { and }
|
||||
start = key.find('{')
|
||||
start = key.find("{")
|
||||
if start != -1:
|
||||
end = key.find('}', start + 1)
|
||||
end = key.find("}", start + 1)
|
||||
if end != -1 and end != start + 1:
|
||||
key = key[start + 1:end]
|
||||
key = key[start + 1 : end]
|
||||
|
||||
# Compute CRC16 and mod 16384
|
||||
crc = binascii.crc_hqx(key.encode('utf-8'), 0)
|
||||
crc = binascii.crc_hqx(key.encode("utf-8"), 0)
|
||||
return crc % REDIS_CLUSTER_SLOTS
|
||||
|
||||
def _group_keys_by_hash_tag(self, keys: List[str]) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Group keys by their Redis hash tag to ensure cluster compatibility.
|
||||
|
||||
|
||||
For Redis clusters, uses slot calculation to group keys that belong to the same slot.
|
||||
For regular Redis, no grouping is needed - all keys can be processed together.
|
||||
"""
|
||||
groups: Dict[str, List[str]] = {}
|
||||
|
||||
|
||||
# Use slot calculation for Redis clusters only
|
||||
if self._is_redis_cluster():
|
||||
for key in keys:
|
||||
slot = self.keyslot_for_redis_cluster(key)
|
||||
slot_key = f"slot_{slot}"
|
||||
|
||||
|
||||
if slot_key not in groups:
|
||||
groups[slot_key] = []
|
||||
groups[slot_key].append(key)
|
||||
@@ -414,7 +417,7 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
||||
Check if any of the rate limit descriptors should be rate limited.
|
||||
Returns a RateLimitResponse with the overall code and status for each descriptor.
|
||||
Uses batch operations for Redis to improve performance.
|
||||
|
||||
|
||||
Args:
|
||||
descriptors: List of rate limit descriptors to check
|
||||
parent_otel_span: Optional OpenTelemetry span for tracing
|
||||
@@ -499,7 +502,7 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
||||
parent_otel_span=parent_otel_span,
|
||||
local_only=False, # Check Redis too
|
||||
)
|
||||
|
||||
|
||||
# For keys that don't exist yet, set them to 0
|
||||
if cache_values is None:
|
||||
cache_values = []
|
||||
@@ -546,6 +549,66 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
||||
)
|
||||
return rate_limit_response
|
||||
|
||||
def _add_model_per_key_rate_limit_descriptor(
    self,
    user_api_key_dict: UserAPIKeyAuth,
    requested_model: Optional[str],
    descriptors: List[RateLimitDescriptor],
) -> None:
    """
    Append a "model_per_key" rate limit descriptor when the API key has a
    TPM and/or RPM limit configured for the requested model.

    Nothing is appended when no model is requested, or when the key has
    neither a TPM nor an RPM entry for that model.

    Args:
        user_api_key_dict: User API key authentication dictionary
        requested_model: The model being requested
        descriptors: List of rate limit descriptors; mutated in place
    """
    from litellm.proxy.auth.auth_utils import (
        get_key_model_rpm_limit,
        get_key_model_tpm_limit,
    )

    if not requested_model:
        return

    # A missing (None) limit map is treated the same as an empty one.
    tpm_limits_by_model = get_key_model_tpm_limit(user_api_key_dict) or {}
    rpm_limits_by_model = get_key_model_rpm_limit(user_api_key_dict) or {}

    # Skip models that have no key-level limit of either kind.
    if (
        requested_model not in tpm_limits_by_model
        and requested_model not in rpm_limits_by_model
    ):
        return

    # Either limit may be None when only the other kind is configured.
    tpm_limit: Optional[int] = tpm_limits_by_model.get(requested_model)
    rpm_limit: Optional[int] = rpm_limits_by_model.get(requested_model)

    descriptors.append(
        RateLimitDescriptor(
            key="model_per_key",
            value=f"{user_api_key_dict.api_key}:{requested_model}",
            rate_limit={
                "requests_per_unit": rpm_limit,
                "tokens_per_unit": tpm_limit,
                "window_size": self.window_size,
            },
        )
    )
|
||||
|
||||
def _should_enforce_rate_limit(
|
||||
self,
|
||||
limit_type: Optional[str],
|
||||
@@ -626,8 +689,8 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
||||
Returns list of descriptors for API key, user, team, team member, end user, and model-specific limits.
|
||||
"""
|
||||
from litellm.proxy.auth.auth_utils import (
|
||||
get_key_model_rpm_limit,
|
||||
get_key_model_tpm_limit,
|
||||
get_team_model_rpm_limit,
|
||||
get_team_model_tpm_limit,
|
||||
)
|
||||
|
||||
descriptors = []
|
||||
@@ -732,29 +795,43 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
||||
|
||||
# Model rate limits
|
||||
requested_model = data.get("model", None)
|
||||
if requested_model and (
|
||||
get_key_model_tpm_limit(user_api_key_dict) is not None
|
||||
or get_key_model_rpm_limit(user_api_key_dict) is not None
|
||||
self._add_model_per_key_rate_limit_descriptor(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
requested_model=requested_model,
|
||||
descriptors=descriptors,
|
||||
)
|
||||
|
||||
if (
|
||||
get_team_model_rpm_limit(user_api_key_dict) is not None
|
||||
or get_team_model_tpm_limit(user_api_key_dict) is not None
|
||||
):
|
||||
_tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict) or {}
|
||||
_rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict) or {}
|
||||
_tpm_limit_for_team_model = (
|
||||
get_team_model_tpm_limit(user_api_key_dict) or {}
|
||||
)
|
||||
_rpm_limit_for_team_model = (
|
||||
get_team_model_rpm_limit(user_api_key_dict) or {}
|
||||
)
|
||||
should_check_rate_limit = False
|
||||
if requested_model in _tpm_limit_for_key_model:
|
||||
if requested_model in _tpm_limit_for_team_model:
|
||||
should_check_rate_limit = True
|
||||
elif requested_model in _rpm_limit_for_key_model:
|
||||
elif requested_model in _rpm_limit_for_team_model:
|
||||
should_check_rate_limit = True
|
||||
|
||||
if should_check_rate_limit:
|
||||
model_specific_tpm_limit: Optional[int] = None
|
||||
model_specific_rpm_limit: Optional[int] = None
|
||||
if requested_model in _tpm_limit_for_key_model:
|
||||
model_specific_tpm_limit = _tpm_limit_for_key_model[requested_model]
|
||||
if requested_model in _rpm_limit_for_key_model:
|
||||
model_specific_rpm_limit = _rpm_limit_for_key_model[requested_model]
|
||||
model_specific_tpm_limit = None
|
||||
model_specific_rpm_limit = None
|
||||
if requested_model in _tpm_limit_for_team_model:
|
||||
model_specific_tpm_limit = _tpm_limit_for_team_model[
|
||||
requested_model
|
||||
]
|
||||
if requested_model in _rpm_limit_for_team_model:
|
||||
model_specific_rpm_limit = _rpm_limit_for_team_model[
|
||||
requested_model
|
||||
]
|
||||
descriptors.append(
|
||||
RateLimitDescriptor(
|
||||
key="model_per_key",
|
||||
value=f"{user_api_key_dict.api_key}:{requested_model}",
|
||||
key="model_per_team",
|
||||
value=f"{user_api_key_dict.team_id}:{requested_model}",
|
||||
rate_limit={
|
||||
"requests_per_unit": model_specific_rpm_limit,
|
||||
"tokens_per_unit": model_specific_tpm_limit,
|
||||
@@ -1164,6 +1241,15 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
)
|
||||
if model_group and user_api_key_team_id:
|
||||
pipeline_operations.extend(
|
||||
self._create_pipeline_operations(
|
||||
key="model_per_team",
|
||||
value=f"{user_api_key_team_id}:{model_group}",
|
||||
rate_limit_type="tokens",
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
# Execute all increments in a single pipeline
|
||||
if pipeline_operations:
|
||||
|
||||
@@ -667,7 +667,8 @@ def check_team_key_model_specific_limits(
|
||||
if data.model_rpm_limit is not None:
|
||||
for model, rpm_limit in data.model_rpm_limit.items():
|
||||
if (
|
||||
model_specific_rpm_limit.get(model, 0) + rpm_limit
|
||||
team_table.rpm_limit is not None
|
||||
and model_specific_rpm_limit.get(model, 0) + rpm_limit
|
||||
> team_table.rpm_limit
|
||||
):
|
||||
raise HTTPException(
|
||||
@@ -687,7 +688,7 @@ def check_team_key_model_specific_limits(
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Allocated RPM limit={model_specific_rpm_limit.get(model, 0)} + Key RPM limit={rpm_limit} is greater than team RPM limit={team_model_specific_rpm_limit.get(model, 0)}",
|
||||
detail=f"Allocated RPM limit={model_specific_rpm_limit.get(model, 0)} + Key RPM limit={rpm_limit} is greater than team RPM limit={team_model_specific_rpm_limit}",
|
||||
)
|
||||
if data.model_tpm_limit is not None:
|
||||
for model, tpm_limit in data.model_tpm_limit.items():
|
||||
|
||||
@@ -300,6 +300,8 @@ async def new_team( # noqa: PLR0915
|
||||
- members_with_roles: List[{"role": "admin" or "user", "user_id": "<user-id>"}] - A list of users and their roles in the team. Get user_id when making a new user via `/user/new`.
|
||||
- team_member_permissions: Optional[List[str]] - A list of routes that non-admin team members can access. example: ["/key/generate", "/key/update", "/key/delete"]
|
||||
- metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"extra_info": "some info"}
|
||||
- model_rpm_limit: Optional[Dict[str, int]] - Per-model RPM (Requests Per Minute) limits for this team, given as a map of model name to limit - each limit is applied across all keys for this team. Example: {"gpt-4": 700}
|
||||
- model_tpm_limit: Optional[Dict[str, int]] - Per-model TPM (Tokens Per Minute) limits for this team, given as a map of model name to limit - each limit is applied across all keys for this team. Example: {"gpt-4": 10000}
|
||||
- tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit
|
||||
- rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit
|
||||
- max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget
|
||||
|
||||
@@ -425,7 +425,7 @@ async def test_key_generation_with_object_permission(monkeypatch):
|
||||
async def test_key_generation_with_mcp_tool_permissions(monkeypatch):
|
||||
"""
|
||||
Test that /key/generate correctly handles mcp_tool_permissions in object_permission.
|
||||
|
||||
|
||||
This test verifies that:
|
||||
1. mcp_tool_permissions is accepted in the object_permission field
|
||||
2. The field is properly stored in the LiteLLM_ObjectPermissionTable
|
||||
@@ -490,6 +490,7 @@ async def test_key_generation_with_mcp_tool_permissions(monkeypatch):
|
||||
# Verify mcp_tool_permissions was stored (serialized to JSON string for GraphQL compatibility)
|
||||
assert "mcp_tool_permissions" in created_permission_data
|
||||
import json
|
||||
|
||||
assert json.loads(created_permission_data["mcp_tool_permissions"]) == {
|
||||
"server_1": ["tool1", "tool2", "tool3"]
|
||||
}
|
||||
@@ -1816,6 +1817,160 @@ def test_check_team_key_model_specific_limits_rpm_overallocation():
|
||||
)
|
||||
|
||||
|
||||
def test_check_team_key_model_specific_limits_team_model_rpm_overallocation():
    """
    check_team_key_model_specific_limits must reject a new key whose per-model
    RPM allocation would push the team past its model-specific RPM limit.

    The team here has no global RPM/TPM limit; its limits are set only in
    team_table.metadata["model_rpm_limit"].
    """
    team_id = "test-team-789"

    # Keys already issued to the team: gpt-4 is allocated 300 + 250 = 550 RPM.
    first_key = LiteLLM_VerificationToken(
        token="test-token-1",
        user_id="test-user-1",
        team_id=team_id,
        metadata={"model_rpm_limit": {"gpt-4": 300, "gpt-3.5-turbo": 200}},
    )
    second_key = LiteLLM_VerificationToken(
        token="test-token-2",
        user_id="test-user-2",
        team_id=team_id,
        metadata={"model_rpm_limit": {"gpt-4": 250}},
    )

    # Team with model-specific RPM limits in metadata (gpt-4 capped at 700).
    team = LiteLLM_TeamTableCachedObj(
        team_id=team_id,
        team_alias="test-team",
        tpm_limit=None,
        rpm_limit=None,
        max_budget=100.0,
        spend=0.0,
        models=[],
        blocked=False,
        members_with_roles=[],
        metadata={"model_rpm_limit": {"gpt-4": 700, "gpt-3.5-turbo": 500}},
    )

    # Requesting 200 more for gpt-4 gives 550 + 200 = 750 > 700.
    request = GenerateKeyRequest(
        model_rpm_limit={"gpt-4": 200},
        model_tpm_limit=None,
    )

    with pytest.raises(HTTPException) as exc_info:
        check_team_key_model_specific_limits(
            keys=[first_key, second_key],
            team_table=team,
            data=request,
        )

    raised = exc_info.value
    assert raised.status_code == 400
    detail = str(raised.detail)
    assert (
        "Allocated RPM limit=550 + Key RPM limit=200 is greater than team RPM limit=700"
        in detail
    )
|
||||
|
||||
|
||||
def test_check_team_key_model_specific_limits_team_model_tpm_overallocation():
    """
    check_team_key_model_specific_limits must reject a new key whose per-model
    TPM allocation would push the team past its model-specific TPM limit.

    The team here has no global RPM/TPM limit; its limits are set only in
    team_table.metadata["model_tpm_limit"].
    """
    team_id = "test-team-101"

    # Keys already issued to the team: gpt-4 is allocated 5000 + 3500 = 8500 TPM.
    first_key = LiteLLM_VerificationToken(
        token="test-token-1",
        user_id="test-user-1",
        team_id=team_id,
        metadata={"model_tpm_limit": {"gpt-4": 5000, "claude-3": 3000}},
    )
    second_key = LiteLLM_VerificationToken(
        token="test-token-2",
        user_id="test-user-2",
        team_id=team_id,
        metadata={"model_tpm_limit": {"gpt-4": 3500}},
    )

    # Team with model-specific TPM limits in metadata (gpt-4 capped at 10000).
    team = LiteLLM_TeamTableCachedObj(
        team_id=team_id,
        team_alias="test-team",
        tpm_limit=None,
        rpm_limit=None,
        max_budget=100.0,
        spend=0.0,
        models=[],
        blocked=False,
        members_with_roles=[],
        metadata={"model_tpm_limit": {"gpt-4": 10000, "claude-3": 8000}},
    )

    # Requesting 2000 more for gpt-4 gives 8500 + 2000 = 10500 > 10000.
    request = GenerateKeyRequest(
        model_rpm_limit=None,
        model_tpm_limit={"gpt-4": 2000},
    )

    with pytest.raises(HTTPException) as exc_info:
        check_team_key_model_specific_limits(
            keys=[first_key, second_key],
            team_table=team,
            data=request,
        )

    raised = exc_info.value
    assert raised.status_code == 400
    detail = str(raised.detail)
    assert (
        "Allocated TPM limit=8500 + Key TPM limit=2000 is greater than team TPM limit=10000"
        in detail
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_key_with_object_permission():
|
||||
"""
|
||||
@@ -1876,9 +2031,7 @@ async def test_generate_key_with_object_permission():
|
||||
with patch(
|
||||
"litellm.proxy.proxy_server.prisma_client",
|
||||
mock_prisma_client,
|
||||
), patch(
|
||||
"litellm.proxy.proxy_server.llm_router", None
|
||||
), patch(
|
||||
), patch("litellm.proxy.proxy_server.llm_router", None), patch(
|
||||
"litellm.proxy.proxy_server.premium_user",
|
||||
False,
|
||||
), patch(
|
||||
|
||||
Reference in New Issue
Block a user