(feat) Team level model-specific tpm/rpm limits + working key-level validation of tpm/rpm limit when assigned to team (#15513)

* fix(support-model-specific-tpm/rpm-limits): Allows setting rate limits by tpm/rpm for models by team

* fix(key_management_endpoints.py): enforce guaranteed throughput with key-level model tpm/rpm limits, when team-level tpm/rpm limits are set

* test: add unit testing

* fix: fix minor linting errors

* fix: refactor
This commit is contained in:
Krish Dholakia
2025-10-18 13:14:04 -07:00
committed by GitHub
parent 46d754a0f9
commit 4e141df03a
6 changed files with 299 additions and 37 deletions
+4
View File
@@ -1284,6 +1284,8 @@ class NewTeamRequest(TeamBase):
prompts: Optional[List[str]] = None
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
allowed_passthrough_routes: Optional[list] = None
model_rpm_limit: Optional[Dict[str, int]] = None
model_tpm_limit: Optional[Dict[str, int]] = None
team_member_budget: Optional[float] = (
None # allow user to set a budget for all team members
)
@@ -1340,6 +1342,8 @@ class UpdateTeamRequest(LiteLLMPydanticObjectBase):
team_member_tpm_limit: Optional[int] = None
team_member_key_duration: Optional[str] = None
allowed_passthrough_routes: Optional[list] = None
model_rpm_limit: Optional[Dict[str, int]] = None
model_tpm_limit: Optional[Dict[str, int]] = None
class ResetTeamBudgetRequest(LiteLLMPydanticObjectBase):
+16
View File
@@ -453,6 +453,22 @@ def get_key_model_tpm_limit(
return None
def get_team_model_rpm_limit(
user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
if user_api_key_dict.team_metadata:
return user_api_key_dict.team_metadata.get("model_rpm_limit")
return None
def get_team_model_tpm_limit(
user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
if user_api_key_dict.team_metadata:
return user_api_key_dict.team_metadata.get("model_tpm_limit")
return None
def is_pass_through_provider_route(route: str) -> bool:
PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES = [
"vertex-ai",
@@ -103,6 +103,7 @@ return results
REDIS_CLUSTER_SLOTS = 16384
REDIS_NODE_HASHTAG_NAME = "all_keys"
class RateLimitDescriptorRateLimitObject(TypedDict, total=False):
requests_per_unit: Optional[int]
tokens_per_unit: Optional[int]
@@ -157,15 +158,17 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
def _is_redis_cluster(self) -> bool:
"""
Check if the dual cache is using Redis cluster.
Returns:
bool: True if using Redis cluster, False otherwise.
"""
from litellm.caching.redis_cluster_cache import RedisClusterCache
return (
self.internal_usage_cache.dual_cache.redis_cache is not None
and isinstance(self.internal_usage_cache.dual_cache.redis_cache, RedisClusterCache)
and isinstance(
self.internal_usage_cache.dual_cache.redis_cache, RedisClusterCache
)
)
async def in_memory_cache_sliding_window(
@@ -310,7 +313,7 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
)
return RateLimitResponse(overall_code=overall_code, statuses=statuses)
def keyslot_for_redis_cluster(self, key: str) -> int:
"""
Compute the Redis Cluster slot for a given key.
@@ -325,34 +328,34 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
Returns:
int: The slot number (0-16383).
"""
# Handle hash tags: use substring between { and }
start = key.find('{')
start = key.find("{")
if start != -1:
end = key.find('}', start + 1)
end = key.find("}", start + 1)
if end != -1 and end != start + 1:
key = key[start + 1:end]
key = key[start + 1 : end]
# Compute CRC16 and mod 16384
crc = binascii.crc_hqx(key.encode('utf-8'), 0)
crc = binascii.crc_hqx(key.encode("utf-8"), 0)
return crc % REDIS_CLUSTER_SLOTS
def _group_keys_by_hash_tag(self, keys: List[str]) -> Dict[str, List[str]]:
"""
Group keys by their Redis hash tag to ensure cluster compatibility.
For Redis clusters, uses slot calculation to group keys that belong to the same slot.
For regular Redis, no grouping is needed - all keys can be processed together.
"""
groups: Dict[str, List[str]] = {}
# Use slot calculation for Redis clusters only
if self._is_redis_cluster():
for key in keys:
slot = self.keyslot_for_redis_cluster(key)
slot_key = f"slot_{slot}"
if slot_key not in groups:
groups[slot_key] = []
groups[slot_key].append(key)
@@ -414,7 +417,7 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
Check if any of the rate limit descriptors should be rate limited.
Returns a RateLimitResponse with the overall code and status for each descriptor.
Uses batch operations for Redis to improve performance.
Args:
descriptors: List of rate limit descriptors to check
parent_otel_span: Optional OpenTelemetry span for tracing
@@ -499,7 +502,7 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
parent_otel_span=parent_otel_span,
local_only=False, # Check Redis too
)
# For keys that don't exist yet, set them to 0
if cache_values is None:
cache_values = []
@@ -546,6 +549,66 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
)
return rate_limit_response
def _add_model_per_key_rate_limit_descriptor(
self,
user_api_key_dict: UserAPIKeyAuth,
requested_model: Optional[str],
descriptors: List[RateLimitDescriptor],
) -> None:
"""
Add model-specific rate limit descriptor for API key if applicable.
Args:
user_api_key_dict: User API key authentication dictionary
requested_model: The model being requested
descriptors: List of rate limit descriptors to append to
"""
from litellm.proxy.auth.auth_utils import (
get_key_model_rpm_limit,
get_key_model_tpm_limit,
)
if not requested_model:
return
_tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
_rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
if _tpm_limit_for_key_model is None and _rpm_limit_for_key_model is None:
return
_tpm_limit_for_key_model = _tpm_limit_for_key_model or {}
_rpm_limit_for_key_model = _rpm_limit_for_key_model or {}
# Check if model has any rate limits configured
should_check_rate_limit = (
requested_model in _tpm_limit_for_key_model
or requested_model in _rpm_limit_for_key_model
)
if not should_check_rate_limit:
return
# Get model-specific limits
model_specific_tpm_limit: Optional[int] = _tpm_limit_for_key_model.get(
requested_model
)
model_specific_rpm_limit: Optional[int] = _rpm_limit_for_key_model.get(
requested_model
)
descriptors.append(
RateLimitDescriptor(
key="model_per_key",
value=f"{user_api_key_dict.api_key}:{requested_model}",
rate_limit={
"requests_per_unit": model_specific_rpm_limit,
"tokens_per_unit": model_specific_tpm_limit,
"window_size": self.window_size,
},
)
)
def _should_enforce_rate_limit(
self,
limit_type: Optional[str],
@@ -626,8 +689,8 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
Returns list of descriptors for API key, user, team, team member, end user, and model-specific limits.
"""
from litellm.proxy.auth.auth_utils import (
get_key_model_rpm_limit,
get_key_model_tpm_limit,
get_team_model_rpm_limit,
get_team_model_tpm_limit,
)
descriptors = []
@@ -732,29 +795,43 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
# Model rate limits
requested_model = data.get("model", None)
if requested_model and (
get_key_model_tpm_limit(user_api_key_dict) is not None
or get_key_model_rpm_limit(user_api_key_dict) is not None
self._add_model_per_key_rate_limit_descriptor(
user_api_key_dict=user_api_key_dict,
requested_model=requested_model,
descriptors=descriptors,
)
if (
get_team_model_rpm_limit(user_api_key_dict) is not None
or get_team_model_tpm_limit(user_api_key_dict) is not None
):
_tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict) or {}
_rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict) or {}
_tpm_limit_for_team_model = (
get_team_model_tpm_limit(user_api_key_dict) or {}
)
_rpm_limit_for_team_model = (
get_team_model_rpm_limit(user_api_key_dict) or {}
)
should_check_rate_limit = False
if requested_model in _tpm_limit_for_key_model:
if requested_model in _tpm_limit_for_team_model:
should_check_rate_limit = True
elif requested_model in _rpm_limit_for_key_model:
elif requested_model in _rpm_limit_for_team_model:
should_check_rate_limit = True
if should_check_rate_limit:
model_specific_tpm_limit: Optional[int] = None
model_specific_rpm_limit: Optional[int] = None
if requested_model in _tpm_limit_for_key_model:
model_specific_tpm_limit = _tpm_limit_for_key_model[requested_model]
if requested_model in _rpm_limit_for_key_model:
model_specific_rpm_limit = _rpm_limit_for_key_model[requested_model]
model_specific_tpm_limit = None
model_specific_rpm_limit = None
if requested_model in _tpm_limit_for_team_model:
model_specific_tpm_limit = _tpm_limit_for_team_model[
requested_model
]
if requested_model in _rpm_limit_for_team_model:
model_specific_rpm_limit = _rpm_limit_for_team_model[
requested_model
]
descriptors.append(
RateLimitDescriptor(
key="model_per_key",
value=f"{user_api_key_dict.api_key}:{requested_model}",
key="model_per_team",
value=f"{user_api_key_dict.team_id}:{requested_model}",
rate_limit={
"requests_per_unit": model_specific_rpm_limit,
"tokens_per_unit": model_specific_tpm_limit,
@@ -1164,6 +1241,15 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
total_tokens=total_tokens,
)
)
if model_group and user_api_key_team_id:
pipeline_operations.extend(
self._create_pipeline_operations(
key="model_per_team",
value=f"{user_api_key_team_id}:{model_group}",
rate_limit_type="tokens",
total_tokens=total_tokens,
)
)
# Execute all increments in a single pipeline
if pipeline_operations:
@@ -667,7 +667,8 @@ def check_team_key_model_specific_limits(
if data.model_rpm_limit is not None:
for model, rpm_limit in data.model_rpm_limit.items():
if (
model_specific_rpm_limit.get(model, 0) + rpm_limit
team_table.rpm_limit is not None
and model_specific_rpm_limit.get(model, 0) + rpm_limit
> team_table.rpm_limit
):
raise HTTPException(
@@ -687,7 +688,7 @@ def check_team_key_model_specific_limits(
):
raise HTTPException(
status_code=400,
detail=f"Allocated RPM limit={model_specific_rpm_limit.get(model, 0)} + Key RPM limit={rpm_limit} is greater than team RPM limit={team_model_specific_rpm_limit.get(model, 0)}",
detail=f"Allocated RPM limit={model_specific_rpm_limit.get(model, 0)} + Key RPM limit={rpm_limit} is greater than team RPM limit={team_model_specific_rpm_limit}",
)
if data.model_tpm_limit is not None:
for model, tpm_limit in data.model_tpm_limit.items():
@@ -300,6 +300,8 @@ async def new_team( # noqa: PLR0915
- members_with_roles: List[{"role": "admin" or "user", "user_id": "<user-id>"}] - A list of users and their roles in the team. Get user_id when making a new user via `/user/new`.
- team_member_permissions: Optional[List[str]] - A list of routes that non-admin team members can access. example: ["/key/generate", "/key/update", "/key/delete"]
- metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"extra_info": "some info"}
- model_rpm_limit: Optional[Dict[str, int]] - The RPM (Requests Per Minute) limit for this team - applied across all keys for this team.
- model_tpm_limit: Optional[Dict[str, int]] - The TPM (Tokens Per Minute) limit for this team - applied across all keys for this team.
- tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit
- rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit
- max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget
@@ -425,7 +425,7 @@ async def test_key_generation_with_object_permission(monkeypatch):
async def test_key_generation_with_mcp_tool_permissions(monkeypatch):
"""
Test that /key/generate correctly handles mcp_tool_permissions in object_permission.
This test verifies that:
1. mcp_tool_permissions is accepted in the object_permission field
2. The field is properly stored in the LiteLLM_ObjectPermissionTable
@@ -490,6 +490,7 @@ async def test_key_generation_with_mcp_tool_permissions(monkeypatch):
# Verify mcp_tool_permissions was stored (serialized to JSON string for GraphQL compatibility)
assert "mcp_tool_permissions" in created_permission_data
import json
assert json.loads(created_permission_data["mcp_tool_permissions"]) == {
"server_1": ["tool1", "tool2", "tool3"]
}
@@ -1816,6 +1817,160 @@ def test_check_team_key_model_specific_limits_rpm_overallocation():
)
def test_check_team_key_model_specific_limits_team_model_rpm_overallocation():
"""
Test check_team_key_model_specific_limits when team has model-specific RPM limits
in metadata and key allocation would exceed those limits.
This tests the scenario where team_table.metadata["model_rpm_limit"] is set
with per-model limits, not just a global team RPM limit.
"""
# Create existing keys with model-specific RPM limits
existing_key1 = LiteLLM_VerificationToken(
token="test-token-1",
user_id="test-user-1",
team_id="test-team-789",
metadata={
"model_rpm_limit": {
"gpt-4": 300,
"gpt-3.5-turbo": 200,
}
},
)
existing_key2 = LiteLLM_VerificationToken(
token="test-token-2",
user_id="test-user-2",
team_id="test-team-789",
metadata={
"model_rpm_limit": {
"gpt-4": 250,
}
},
)
keys = [existing_key1, existing_key2]
# Create team table with model-specific RPM limits in metadata
team_table = LiteLLM_TeamTableCachedObj(
team_id="test-team-789",
team_alias="test-team",
tpm_limit=None,
rpm_limit=None,
max_budget=100.0,
spend=0.0,
models=[],
blocked=False,
members_with_roles=[],
metadata={
"model_rpm_limit": {
"gpt-4": 700, # Team-level model-specific limit for gpt-4
"gpt-3.5-turbo": 500,
}
},
)
# Create request that would exceed team's model-specific RPM limits
# Existing gpt-4: 300 + 250 = 550, New: 200, Total: 750 > 700 (team model-specific limit)
data = GenerateKeyRequest(
model_rpm_limit={
"gpt-4": 200, # This would cause overallocation against team model-specific limit
},
model_tpm_limit=None,
)
# Should raise HTTPException for team model-specific RPM overallocation
with pytest.raises(HTTPException) as exc_info:
check_team_key_model_specific_limits(
keys=keys,
team_table=team_table,
data=data,
)
assert exc_info.value.status_code == 400
assert (
"Allocated RPM limit=550 + Key RPM limit=200 is greater than team RPM limit=700"
in str(exc_info.value.detail)
)
def test_check_team_key_model_specific_limits_team_model_tpm_overallocation():
"""
Test check_team_key_model_specific_limits when team has model-specific TPM limits
in metadata and key allocation would exceed those limits.
This tests the scenario where team_table.metadata["model_tpm_limit"] is set
with per-model limits, not just a global team TPM limit.
"""
# Create existing keys with model-specific TPM limits
existing_key1 = LiteLLM_VerificationToken(
token="test-token-1",
user_id="test-user-1",
team_id="test-team-101",
metadata={
"model_tpm_limit": {
"gpt-4": 5000,
"claude-3": 3000,
}
},
)
existing_key2 = LiteLLM_VerificationToken(
token="test-token-2",
user_id="test-user-2",
team_id="test-team-101",
metadata={
"model_tpm_limit": {
"gpt-4": 3500,
}
},
)
keys = [existing_key1, existing_key2]
# Create team table with model-specific TPM limits in metadata
team_table = LiteLLM_TeamTableCachedObj(
team_id="test-team-101",
team_alias="test-team",
tpm_limit=None,
rpm_limit=None,
max_budget=100.0,
spend=0.0,
models=[],
blocked=False,
members_with_roles=[],
metadata={
"model_tpm_limit": {
"gpt-4": 10000, # Team-level model-specific limit for gpt-4
"claude-3": 8000,
}
},
)
# Create request that would exceed team's model-specific TPM limits
# Existing gpt-4: 5000 + 3500 = 8500, New: 2000, Total: 10500 > 10000 (team model-specific limit)
data = GenerateKeyRequest(
model_rpm_limit=None,
model_tpm_limit={
"gpt-4": 2000, # This would cause overallocation against team model-specific limit
},
)
# Should raise HTTPException for team model-specific TPM overallocation
with pytest.raises(HTTPException) as exc_info:
check_team_key_model_specific_limits(
keys=keys,
team_table=team_table,
data=data,
)
assert exc_info.value.status_code == 400
assert (
"Allocated TPM limit=8500 + Key TPM limit=2000 is greater than team TPM limit=10000"
in str(exc_info.value.detail)
)
@pytest.mark.asyncio
async def test_generate_key_with_object_permission():
"""
@@ -1876,9 +2031,7 @@ async def test_generate_key_with_object_permission():
with patch(
"litellm.proxy.proxy_server.prisma_client",
mock_prisma_client,
), patch(
"litellm.proxy.proxy_server.llm_router", None
), patch(
), patch("litellm.proxy.proxy_server.llm_router", None), patch(
"litellm.proxy.proxy_server.premium_user",
False,
), patch(