diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 640df7cb0..00922e714 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -970,12 +970,10 @@ class NewUserResponse(GenerateKeyResponse): updated_at: Optional[datetime] = None -class UpdateUserRequest(GenerateRequestBase): - # Note: the defaults of all Params here MUST BE NONE - # else they will get overwritten - user_id: Optional[str] = None +class UpdateUserRequestNoUserIDorEmail( + GenerateRequestBase +): # shared with BulkUpdateUserRequest password: Optional[str] = None - user_email: Optional[str] = None spend: Optional[float] = None metadata: Optional[dict] = None user_role: Optional[ @@ -988,6 +986,13 @@ class UpdateUserRequest(GenerateRequestBase): ] = None max_budget: Optional[float] = None + +class UpdateUserRequest(UpdateUserRequestNoUserIDorEmail): + # Note: the defaults of all Params here MUST BE NONE + # else they will get overwritten + user_id: Optional[str] = None + user_email: Optional[str] = None + @model_validator(mode="before") @classmethod def check_user_info(cls, values): diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index c059b4b58..1610ed262 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -12,6 +12,7 @@ These are members of a Team on LiteLLM """ import asyncio +import json import traceback import uuid from datetime import datetime, timezone @@ -744,7 +745,9 @@ def _process_keys_for_user_info( return returned_keys -def _update_internal_user_params(data_json: dict, data: UpdateUserRequest) -> dict: +def _update_internal_user_params( + data_json: dict, data: Union[UpdateUserRequest, UpdateUserRequestNoUserIDorEmail] +) -> dict: non_default_values = {} for k, v in data_json.items(): if ( @@ -1015,6 +1018,69 @@ async def user_update( ) +async def bulk_update_processed_users( + users_to_update: List[UpdateUserRequest], + user_api_key_dict: UserAPIKeyAuth, + litellm_changed_by: Optional[str] = None, +) -> BulkUpdateUserResponse: + results: List[UserUpdateResult] = [] + successful_updates = 0 + failed_updates = 0 + + # Process each user update independently + try: + for user_request in users_to_update: + try: + response = await _update_single_user_helper( + user_request=user_request, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=litellm_changed_by, + ) + # Record success + results.append( + UserUpdateResult( + user_id=( + response.get("user_id") + if response + else user_request.user_id + ), + user_email=user_request.user_email, + success=True, + updated_user=response, + ) + ) + successful_updates += 1 + except Exception as e: + verbose_proxy_logger.exception( + f"Failed to update user {user_request.user_id or user_request.user_email}: {e}" + ) + # Record failure + error_message = str(e) + verbose_proxy_logger.error( + f"Failed to update user {user_request.user_id or user_request.user_email}: {error_message}" + ) + + results.append( + UserUpdateResult( + user_id=user_request.user_id, + user_email=user_request.user_email, + success=False, + error=error_message, + ) + ) + failed_updates += 1 + + return BulkUpdateUserResponse( + results=results, + total_requested=len(users_to_update), + successful_updates=successful_updates, + failed_updates=failed_updates, + ) + except Exception as e: + verbose_proxy_logger.exception(f"Failed to update users: {e}") + raise HTTPException(status_code=500, detail={"error": str(e)}) + + @router.post( "/user/bulk_update", tags=["Internal User management"], @@ -1037,7 +1103,9 @@ async def bulk_user_update( is processed independently - if some updates fail, others will still succeed. Parameters: - - users: List[UpdateUserRequest] - List of user update requests + - users: Optional[List[UpdateUserRequest]] - List of specific user update requests + - all_users: Optional[bool] - Set to true to update all users in the system + - user_updates: Optional[UpdateUserRequest] - Updates to apply when all_users=True Returns: - results: List of individual update results @@ -1045,7 +1113,7 @@ async def bulk_user_update( - successful_updates: Number of successful updates - failed_updates: Number of failed updates - Example request: + Example request for specific users: ```bash curl --location 'http://0.0.0.0:4000/user/bulk_update' \ --header 'Authorization: Bearer sk-1234' \ @@ -1065,8 +1133,22 @@ async def bulk_user_update( ] }' ``` + + Example request for all users: + ```bash + curl --location 'http://0.0.0.0:4000/user/bulk_update' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "all_users": true, + "user_updates": { + "user_role": "internal_user", + "max_budget": 50.0 + } + }' + ``` """ - from litellm.proxy.proxy_server import prisma_client + from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client if prisma_client is None: raise HTTPException( @@ -1074,69 +1156,130 @@ async def bulk_user_update( detail={"error": "Database not connected"}, ) - if not data.users: + # Determine the list of users to update + users_to_update: Union[ + List[UpdateUserRequest], List[UpdateUserRequestNoUserIDorEmail] + ] = [] + + if data.all_users and data.user_updates: + # Optimized path for updating all users directly in database + all_users_in_db = await prisma_client.db.litellm_usertable.find_many( + order={"created_at": "desc"} + ) + + if not all_users_in_db: + raise HTTPException( + status_code=400, + detail={"error": "No users found to update"}, + ) + + # Limit batch size to prevent overwhelming the system + MAX_BATCH_SIZE = 500 # Increased limit for all-users operations + if len(all_users_in_db) > MAX_BATCH_SIZE: + raise HTTPException( + status_code=400, + detail={ + "error": f"Maximum {MAX_BATCH_SIZE} users can be updated at once. Found {len(all_users_in_db)} users." + }, + ) + + # Apply update transformations (reuse existing logic) + data_json: dict = data.user_updates.model_dump(exclude_unset=True) + non_default_values = _update_internal_user_params( + data_json=data_json, data=data.user_updates + ) + + # Remove user identification fields since we're updating by user_id + non_default_values.pop("user_id", None) + non_default_values.pop("user_email", None) + + successful_updates = 0 + failed_updates = 0 + results: List[UserUpdateResult] = [] + + try: + # Perform bulk database update + await prisma_client.db.litellm_usertable.update_many( + where={}, data=non_default_values # Update all users + ) + + # Create individual success results + for user in all_users_in_db: + results.append( + UserUpdateResult( + user_id=user.user_id, + user_email=user.user_email, + success=True, + updated_user={"user_id": user.user_id, **non_default_values}, + ) + ) + successful_updates += 1 + + # Create single audit log entry for bulk operation + try: + asyncio.create_task( + UserManagementEventHooks.create_internal_user_audit_log( + user_id=user_api_key_dict.user_id or "", + action="updated", + litellm_changed_by=litellm_changed_by + or user_api_key_dict.user_id, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + before_value=f"Updated {len(all_users_in_db)} users", + after_value=json.dumps(non_default_values), + ) + ) + except Exception as audit_error: + verbose_proxy_logger.warning( + f"Failed to create bulk audit log: {audit_error}" + ) + + except Exception as e: + verbose_proxy_logger.exception(f"Failed to perform bulk update: {e}") + # Fall back to individual updates if bulk update fails + for user in all_users_in_db: + user_update_request = data.user_updates.model_copy() + user_update_request.user_id = user.user_id + users_to_update.append(user_update_request) # type: ignore + + if successful_updates > 0: + return BulkUpdateUserResponse( + results=results, + total_requested=len(all_users_in_db), + successful_updates=successful_updates, + failed_updates=failed_updates, + ) + + elif data.users: + users_to_update = data.users + else: raise HTTPException( status_code=400, - detail={"error": "At least one user update request is required"}, + detail={ + "error": "Must specify either 'users' for individual updates or 'all_users=True' with 'user_updates' for bulk updates" + }, + ) + + if not users_to_update: + raise HTTPException( + status_code=400, + detail={"error": "No users found to update"}, ) # Limit batch size to prevent overwhelming the system - MAX_BATCH_SIZE = 100 - if len(data.users) > MAX_BATCH_SIZE: + MAX_BATCH_SIZE = 500 # Increased limit for all-users operations + if len(users_to_update) > MAX_BATCH_SIZE: raise HTTPException( status_code=400, - detail={"error": f"Maximum {MAX_BATCH_SIZE} users can be updated at once"}, + detail={ + "error": f"Maximum {MAX_BATCH_SIZE} users can be updated at once. Found {len(users_to_update)} users." + }, ) - results: List[UserUpdateResult] = [] - successful_updates = 0 - failed_updates = 0 - - # Process each user update independently - for user_request in data.users: - try: - response = await _update_single_user_helper( - user_request=user_request, - user_api_key_dict=user_api_key_dict, - litellm_changed_by=litellm_changed_by, - ) - # Record success - results.append( - UserUpdateResult( - user_id=( - response.get("user_id") if response else user_request.user_id - ), - user_email=user_request.user_email, - success=True, - updated_user=response, - ) - ) - successful_updates += 1 - except Exception as e: - verbose_proxy_logger.exception( - f"Failed to update user {user_request.user_id or user_request.user_email}: {e}" - ) - # Record failure - error_message = str(e) - verbose_proxy_logger.error( - f"Failed to update user {user_request.user_id or user_request.user_email}: {error_message}" - ) - - results.append( - UserUpdateResult( - user_id=user_request.user_id, - user_email=user_request.user_email, - success=False, - error=error_message, - ) - ) - failed_updates += 1 - - return BulkUpdateUserResponse( - results=results, - total_requested=len(data.users), - successful_updates=successful_updates, - failed_updates=failed_updates, + return await bulk_update_processed_users( + users_to_update=cast(List[UpdateUserRequest], users_to_update), + user_api_key_dict=user_api_key_dict, + litellm_changed_by=litellm_changed_by, ) diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 4867a4c59..a1e96610a 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -961,28 +961,39 @@ def team_member_add_duplication_check( obvious duplicates where both user_id and user_email match exactly. """ + invalid_team_members = [] + def _check_member_duplication(member: Member): # Check by user_id if provided if member.user_id is not None: for existing_member in existing_team_row.members_with_roles: if existing_member.user_id == member.user_id: - raise ProxyException( - message=f"User with user_id={member.user_id} already in team. Existing members={existing_team_row.members_with_roles}", - type=ProxyErrorTypes.team_member_already_in_team, - param="user_id", - code="400", - ) + invalid_team_members.append(member) # Check by user_email if provided if member.user_email is not None: for existing_member in existing_team_row.members_with_roles: if existing_member.user_email == member.user_email: - raise ProxyException( - message=f"User with user_email={member.user_email} already in team. Existing members={existing_team_row.members_with_roles}", - type=ProxyErrorTypes.team_member_already_in_team, - param="user_email", - code="400", - ) + invalid_team_members.append(member) + + if isinstance(data.member, list) and len(invalid_team_members) == len(data.member): + raise ProxyException( + message=f"All users are already in team. Existing members={existing_team_row.members_with_roles}", + type=ProxyErrorTypes.team_member_already_in_team, + param="user_email", + code="400", + ) + elif isinstance(data.member, Member) and len(invalid_team_members) == 1: + raise ProxyException( + message=f"User with user_email={data.member.user_email} already in team. Existing members={existing_team_row.members_with_roles}", + type=ProxyErrorTypes.team_member_already_in_team, + param="user_email", + code="400", + ) + elif len(invalid_team_members) > 0: + verbose_proxy_logger.info( + f"Some users are already in team. Existing members={existing_team_row.members_with_roles}. Duplicate members={invalid_team_members}", + ) if isinstance(data.member, Member): _check_member_duplication(data.member) @@ -1617,6 +1628,7 @@ async def bulk_team_member_add( Parameters: - team_id: str - The ID of the team to add members to - members: List[Member] - List of members to add to the team + - all_users: Optional[bool] - Flag to add all users on Proxy to the team - max_budget_in_team: Optional[float] - Maximum budget allocated to each user within the team Returns: @@ -1647,6 +1659,29 @@ async def bulk_team_member_add( }' ``` """ + from litellm.proxy._types import CommonProxyErrors + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=500, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + if data.all_users: + # get all users from the database + all_users_in_db = await prisma_client.db.litellm_usertable.find_many( + order={"created_at": "desc"} + ) + data.members = [ + Member( + user_id=user.user_id, + user_email=user.user_email, + role="user", + ) + for user in all_users_in_db + ] + if not data.members: raise HTTPException( status_code=400, @@ -1654,7 +1689,7 @@ async def bulk_team_member_add( ) # Limit batch size to prevent overwhelming the system - MAX_BATCH_SIZE = 100 + MAX_BATCH_SIZE = 500 if len(data.members) > MAX_BATCH_SIZE: raise HTTPException( status_code=400, @@ -1686,6 +1721,7 @@ async def bulk_team_member_add( except Exception as e: # If the entire operation fails, mark all members as failed + verbose_proxy_logger.exception(e) error_message = str(e) results = [ TeamMemberAddResult( diff --git a/litellm/types/proxy/management_endpoints/internal_user_endpoints.py b/litellm/types/proxy/management_endpoints/internal_user_endpoints.py index a3a68a04c..6023094a9 100644 --- a/litellm/types/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/types/proxy/management_endpoints/internal_user_endpoints.py @@ -1,9 +1,13 @@ from typing import Any, Dict, List, Literal, Optional, Union from fastapi import HTTPException -from pydantic import BaseModel, EmailStr +from pydantic import BaseModel, EmailStr, field_validator -from litellm.proxy._types import LiteLLM_UserTableWithKeyCount, UpdateUserRequest +from litellm.proxy._types import ( + LiteLLM_UserTableWithKeyCount, + UpdateUserRequest, + UpdateUserRequestNoUserIDorEmail, +) class UserListResponse(BaseModel): @@ -21,7 +25,41 @@ class UserListResponse(BaseModel): class BulkUpdateUserRequest(BaseModel): """Request for bulk user updates""" - users: List[UpdateUserRequest] # List of user update requests + users: Optional[List[UpdateUserRequest]] = ( + None # List of specific user update requests + ) + all_users: Optional[bool] = False # Flag to update all users + user_updates: Optional[UpdateUserRequestNoUserIDorEmail] = ( + None # Updates to apply to all users when all_users=True + ) + + @field_validator("users", "all_users", "user_updates") + @classmethod + def validate_request(cls, v, info): + # Get all field values for validation + values = info.data if hasattr(info, "data") else {} + + # After all fields are set, validate the combination + if ( + info.field_name == "user_updates" + ): # This is the last field, do validation here + users = values.get("users") + all_users = values.get("all_users", False) + user_updates = v + + # Must specify either users list OR all_users with user_updates + if not users and not (all_users and user_updates): + raise ValueError( + "Must specify either 'users' for individual updates or 'all_users=True' with 'user_updates' for bulk updates" + ) + + # Cannot specify both users list and all_users + if users and all_users: + raise ValueError( + "Cannot specify both 'users' and 'all_users=True'. Choose one approach." + ) + + return v class UserUpdateResult(BaseModel): diff --git a/litellm/types/proxy/management_endpoints/team_endpoints.py b/litellm/types/proxy/management_endpoints/team_endpoints.py index 581655f0e..957e5d60e 100644 --- a/litellm/types/proxy/management_endpoints/team_endpoints.py +++ b/litellm/types/proxy/management_endpoints/team_endpoints.py @@ -56,7 +56,8 @@ class BulkTeamMemberAddRequest(BaseModel): """Request for bulk team member addition""" team_id: str - members: List[Member] # List of members to add + members: Optional[List[Member]] = None # List of members to add + all_users: Optional[bool] = False # Flag to add all users on Proxy to the team max_budget_in_team: Optional[float] = None diff --git a/test_bulk_update_all_users.py b/test_bulk_update_all_users.py new file mode 100644 index 000000000..a281f914d --- /dev/null +++ b/test_bulk_update_all_users.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +""" +Test script for the new bulk update "all users" functionality. + +This script demonstrates how to use the enhanced bulk_update endpoint +to update all users in the system at once. +""" + +import requests +import json + +# Configuration +PROXY_BASE_URL = "http://localhost:4000" +ACCESS_TOKEN = "sk-1234" # Replace with your actual access token + + +def test_bulk_update_specific_users(): + """Test the existing functionality - updating specific users.""" + print("=== Testing bulk update for specific users ===") + + url = f"{PROXY_BASE_URL}/user/bulk_update" + headers = { + "Authorization": f"Bearer {ACCESS_TOKEN}", + "Content-Type": "application/json", + } + + # Example payload for updating specific users + payload = { + "users": [ + {"user_id": "user1", "user_role": "internal_user", "max_budget": 100.0}, + { + "user_email": "user2@example.com", + "user_role": "internal_user_viewer", + "max_budget": 50.0, + }, + ] + } + + try: + response = requests.post(url, headers=headers, json=payload) + print(f"Status Code: {response.status_code}") + print(f"Response: {response.json()}") + except Exception as e: + print(f"Error: {e}") + + +def test_bulk_update_all_users(): + """Test the new functionality - updating all users.""" + print("\n=== Testing bulk update for ALL users ===") + + url = f"{PROXY_BASE_URL}/user/bulk_update" + headers = { + "Authorization": f"Bearer {ACCESS_TOKEN}", + "Content-Type": "application/json", + } + + # Example payload for updating ALL users + payload = { + "all_users": True, + "user_updates": {"user_role": "internal_user", "max_budget": 75.0}, + } + + try: + response = requests.post(url, headers=headers, json=payload) + print(f"Status Code: {response.status_code}") + print(f"Response: {response.json()}") + except Exception as e: + print(f"Error: {e}") + + +def test_validation_errors(): + """Test validation errors for invalid payloads.""" + print("\n=== Testing validation errors ===") + + url = f"{PROXY_BASE_URL}/user/bulk_update" + headers = { + "Authorization": f"Bearer {ACCESS_TOKEN}", + "Content-Type": "application/json", + } + + # Test 1: Empty payload + print("Test 1: Empty payload") + try: + response = requests.post(url, headers=headers, json={}) + print(f"Status Code: {response.status_code}") + print(f"Response: {response.json()}") + except Exception as e: + print(f"Error: {e}") + + # Test 2: Both users and all_users specified + print("\nTest 2: Both users and all_users specified") + try: + payload = { + "users": [{"user_id": "user1", "user_role": "internal_user"}], + "all_users": True, + "user_updates": {"user_role": "internal_user"}, + } + response = requests.post(url, headers=headers, json=payload) + print(f"Status Code: {response.status_code}") + print(f"Response: {response.json()}") + except Exception as e: + print(f"Error: {e}") + + # Test 3: all_users=True but no user_updates + print("\nTest 3: all_users=True but no user_updates") + try: + payload = {"all_users": True} + response = requests.post(url, headers=headers, json=payload) + print(f"Status Code: {response.status_code}") + print(f"Response: {response.json()}") + except Exception as e: + print(f"Error: {e}") + + +if __name__ == "__main__": + print("Bulk Update All Users Test Script") + print("==================================") + + # Note: Comment out tests as needed + # test_bulk_update_specific_users() + # test_bulk_update_all_users() # BE CAREFUL with this one! + test_validation_errors() + + print("\n✅ Test script completed!") + print("\nNOTE: The 'test_bulk_update_all_users()' function is commented out") + print("to prevent accidentally updating all users. Uncomment it carefully!") diff --git a/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py index efc3db49e..820d8bd76 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py @@ -38,6 +38,11 @@ from litellm.proxy.management_helpers.team_member_permission_checks import ( ) from litellm.proxy.proxy_server import app from litellm.router import Router +from litellm.types.proxy.management_endpoints.team_endpoints import ( + BulkTeamMemberAddRequest, + BulkTeamMemberAddResponse, + TeamMemberAddResult, +) # Setup TestClient client = TestClient(app) @@ -1172,3 +1177,294 @@ async def test_update_team_team_member_budget_not_passed_to_db(): print( "✅ All test cases passed: team_member_budget is properly excluded from database update operations" ) + + +@pytest.mark.asyncio +async def test_bulk_team_member_add_success(): + """ + Test bulk_team_member_add with successful addition of multiple members + """ + from litellm.proxy._types import ( + LiteLLM_TeamMembership, + LiteLLM_UserTable, + TeamAddMemberResponse, + ) + from litellm.proxy.management_endpoints.team_endpoints import bulk_team_member_add + + # Create test data + test_members = [ + Member(user_email="user1@example.com", role="user"), + Member(user_email="user2@example.com", role="admin"), + ] + + bulk_request = BulkTeamMemberAddRequest( + team_id="test-team-123", + members=test_members, + max_budget_in_team=100.0, + ) + + # Mock successful team_member_add response using MagicMock for simplicity + mock_user_1 = MagicMock(spec=LiteLLM_UserTable) + mock_user_1.user_id = "user-1" + mock_user_1.user_email = "user1@example.com" + mock_user_1.model_dump.return_value = { + "user_id": "user-1", + "user_email": "user1@example.com", + } + + mock_user_2 = MagicMock(spec=LiteLLM_UserTable) + mock_user_2.user_id = "user-2" + mock_user_2.user_email = "user2@example.com" + mock_user_2.model_dump.return_value = { + "user_id": "user-2", + "user_email": "user2@example.com", + } + + mock_updated_users = [mock_user_1, mock_user_2] + + mock_membership_1 = MagicMock(spec=LiteLLM_TeamMembership) + mock_membership_1.user_id = "user-1" + mock_membership_1.team_id = "test-team-123" + mock_membership_1.model_dump.return_value = { + "user_id": "user-1", + "team_id": "test-team-123", + } + + mock_membership_2 = MagicMock(spec=LiteLLM_TeamMembership) + mock_membership_2.user_id = "user-2" + mock_membership_2.team_id = "test-team-123" + mock_membership_2.model_dump.return_value = { + "user_id": "user-2", + "team_id": "test-team-123", + } + + mock_updated_memberships = [mock_membership_1, mock_membership_2] + + # Create a mock response that has model_dump method + mock_team_response = MagicMock() + mock_team_response.team_id = "test-team-123" + mock_team_response.team_alias = "Test Team" + mock_team_response.updated_users = mock_updated_users + mock_team_response.updated_team_memberships = mock_updated_memberships + mock_team_response.model_dump.return_value = { + "team_id": "test-team-123", + "team_alias": "Test Team", + "updated_users": [u.model_dump() for u in mock_updated_users], + "updated_team_memberships": [m.model_dump() for m in mock_updated_memberships], + } + + with patch( + "litellm.proxy.management_endpoints.team_endpoints.team_member_add", + new_callable=AsyncMock, + return_value=mock_team_response, + ) as mock_team_member_add: + + mock_auth = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN) + + result = await bulk_team_member_add( + data=bulk_request, + user_api_key_dict=mock_auth, + ) + + # Verify the result structure + assert isinstance(result, BulkTeamMemberAddResponse) + assert result.team_id == "test-team-123" + assert result.total_requested == 2 + assert result.successful_additions == 2 + assert result.failed_additions == 0 + assert len(result.results) == 2 + + # Verify individual results + for i, member_result in enumerate(result.results): + assert isinstance(member_result, TeamMemberAddResult) + assert member_result.success is True + assert member_result.error is None + assert member_result.user_email == test_members[i].user_email + + # Verify team_member_add was called with correct data + mock_team_member_add.assert_called_once() + call_args = mock_team_member_add.call_args[1]["data"] + assert call_args.team_id == "test-team-123" + assert call_args.member == test_members + assert call_args.max_budget_in_team == 100.0 + + +@pytest.mark.asyncio +async def test_bulk_team_member_add_no_members_error(): + """ + Test bulk_team_member_add raises error when no members provided + """ + from litellm.proxy.management_endpoints.team_endpoints import bulk_team_member_add + + bulk_request = BulkTeamMemberAddRequest( + team_id="test-team-123", + members=[], # Empty list + ) + + mock_auth = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN) + + with pytest.raises(HTTPException) as exc_info: + await bulk_team_member_add( + data=bulk_request, + user_api_key_dict=mock_auth, + ) + + assert exc_info.value.status_code == 400 + assert "At least one member is required" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_bulk_team_member_add_batch_size_limit(): + """ + Test bulk_team_member_add enforces maximum batch size limit + """ + from litellm.proxy.management_endpoints.team_endpoints import bulk_team_member_add + + # Create more than 500 members (the max batch size) + large_member_list = [ + Member(user_email=f"user{i}@example.com", role="user") for i in range(501) + ] + + bulk_request = BulkTeamMemberAddRequest( + team_id="test-team-123", + members=large_member_list, + ) + + mock_auth = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN) + + with pytest.raises(HTTPException) as exc_info: + await bulk_team_member_add( + data=bulk_request, + user_api_key_dict=mock_auth, + ) + + assert exc_info.value.status_code == 400 + assert "Maximum 500 members can be added at once" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_bulk_team_member_add_all_users_flag(): + """ + Test bulk_team_member_add with all_users flag set to True + """ + from litellm.proxy._types import LiteLLM_UserTable, TeamAddMemberResponse + from litellm.proxy.management_endpoints.team_endpoints import bulk_team_member_add + + bulk_request = BulkTeamMemberAddRequest( + team_id="test-team-123", + all_users=True, + max_budget_in_team=50.0, + ) + + # Mock database users + mock_db_users = [ + MagicMock(user_id="user-1", user_email="user1@example.com"), + MagicMock(user_id="user-2", user_email="user2@example.com"), + ] + + mock_team_response = TeamAddMemberResponse( + team_id="test-team-123", + team_alias="Test Team", + updated_users=[], + updated_team_memberships=[], + ) + + with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch( + "litellm.proxy.management_endpoints.team_endpoints.team_member_add", + new_callable=AsyncMock, + return_value=mock_team_response, + ) as mock_team_member_add: + + # Mock the database find_many call + mock_prisma.db.litellm_usertable.find_many = AsyncMock( + return_value=mock_db_users + ) + + mock_auth = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN) + + result = await bulk_team_member_add( + data=bulk_request, + user_api_key_dict=mock_auth, + ) + + # Verify that find_many was called to get all users + mock_prisma.db.litellm_usertable.find_many.assert_called_once_with( + order={"created_at": "desc"} + ) + + # Verify team_member_add was called with users from database + mock_team_member_add.assert_called_once() + call_args = mock_team_member_add.call_args[1]["data"] + assert call_args.team_id == "test-team-123" + assert len(call_args.member) == 2 # Should have 2 members from mock_db_users + assert call_args.max_budget_in_team == 50.0 + + +@pytest.mark.asyncio +async def test_bulk_team_member_add_failure_scenario(): + """ + Test bulk_team_member_add handles failures gracefully + """ + from litellm.proxy.management_endpoints.team_endpoints import bulk_team_member_add + + test_members = [ + Member(user_email="user1@example.com", role="user"), + Member(user_email="user2@example.com", role="admin"), + ] + + bulk_request = BulkTeamMemberAddRequest( + team_id="test-team-123", + members=test_members, + ) + + with patch( + "litellm.proxy.management_endpoints.team_endpoints.team_member_add", + new_callable=AsyncMock, + side_effect=Exception("Database connection failed"), + ) as mock_team_member_add: + + mock_auth = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN) + + result = await bulk_team_member_add( + data=bulk_request, + user_api_key_dict=mock_auth, + ) + + # Verify failure response structure + assert isinstance(result, BulkTeamMemberAddResponse) + assert result.team_id == "test-team-123" + assert result.total_requested == 2 + assert result.successful_additions == 0 + assert result.failed_additions == 2 + assert result.updated_team is None + + # Verify all members marked as failed + assert len(result.results) == 2 + for member_result in result.results: + assert member_result.success is False + assert member_result.error == "Database connection failed" + + +@pytest.mark.asyncio +async def test_bulk_team_member_add_no_db_connection(): + """ + Test bulk_team_member_add handles missing database connection + """ + from litellm.proxy.management_endpoints.team_endpoints import bulk_team_member_add + + bulk_request = BulkTeamMemberAddRequest( + team_id="test-team-123", + members=[Member(user_email="user1@example.com", role="user")], + ) + + with patch("litellm.proxy.proxy_server.prisma_client", None): + mock_auth = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN) + + with pytest.raises(HTTPException) as exc_info: + await bulk_team_member_add( + data=bulk_request, + user_api_key_dict=mock_auth, + ) + + assert exc_info.value.status_code == 500 + assert "DB not connected" in str(exc_info.value.detail) diff --git a/ui/litellm-dashboard/src/components/bulk_edit_user.tsx b/ui/litellm-dashboard/src/components/bulk_edit_user.tsx index 2ed4ef196..4c0be6eb7 100644 --- a/ui/litellm-dashboard/src/components/bulk_edit_user.tsx +++ b/ui/litellm-dashboard/src/components/bulk_edit_user.tsx @@ -14,7 +14,7 @@ import { Checkbox, } from "antd"; import { Button } from '@tremor/react'; -import { userBulkUpdateUserCall, teamBulkMemberAddCall } from "./networking"; +import { userBulkUpdateUserCall, teamBulkMemberAddCall, Member } from "./networking"; import { UserEditView } from "./user_edit_view"; const { Text, Title } = Typography; @@ -29,6 +29,7 @@ interface BulkEditUserModalProps { teams: any[] | null; userRole: string | null; userModels: string[]; + allowAllUsers?: boolean; // Optional flag to enable "all users" mode } const BulkEditUserModal: React.FC = ({ @@ -41,22 +42,25 @@ const BulkEditUserModal: React.FC = ({ teams, userRole, userModels, + allowAllUsers = false, }) => { const [loading, setLoading] = useState(false); const [selectedTeams, setSelectedTeams] = useState([]); const [teamBudget, setTeamBudget] = useState(null); const [addToTeams, setAddToTeams] = useState(false); + const [updateAllUsers, setUpdateAllUsers] = useState(false); const handleCancel = () => { // Reset team management state setSelectedTeams([]); setTeamBudget(null); setAddToTeams(false); + setUpdateAllUsers(false); onCancel(); }; // Create a mock userData object for the UserEditView - const mockUserData = { + const mockUserData = React.useMemo(() => ({ user_id: "bulk_edit", user_info: { user_email: "", @@ -71,9 +75,10 @@ const BulkEditUserModal: React.FC = ({ }, keys: [], teams: teams || [], - }; + }), [teams, visible]); const handleSubmit = async (formValues: any) => { + console.log("formValues", formValues); if (!accessToken) { message.error("Access token not found"); return; @@ -115,8 +120,13 @@ const BulkEditUserModal: React.FC = ({ // Handle user property updates if (hasUserUpdates) { - await userBulkUpdateUserCall(accessToken, updatePayload, userIds); - successMessages.push(`Updated ${userIds.length} user(s)`); + if (updateAllUsers) { + const result = await userBulkUpdateUserCall(accessToken, updatePayload, undefined, true); + successMessages.push(`Updated all users (${result.total_requested} total)`); + } else { + await userBulkUpdateUserCall(accessToken, updatePayload, userIds); + successMessages.push(`Updated ${userIds.length} user(s)`); + } } // Handle team additions @@ -126,18 +136,26 @@ const BulkEditUserModal: React.FC = ({ for (const teamId of selectedTeams) { try { // Create member objects for bulk add - const members = selectedUsers.map(user => ({ - user_id: user.user_id, - role: "user" as const, // Default role for bulk add - user_email: user.user_email || null, - })); + let members: Member[] | null = null; + if (updateAllUsers) { + members = null; + } else { + const members = selectedUsers.map(user => ({ + user_id: user.user_id, + role: "user" as const, // Default role for bulk add + user_email: user.user_email || null, + })); + } const result = await teamBulkMemberAddCall( accessToken, teamId, - members, - teamBudget || undefined + members ? members : null, + teamBudget || undefined, + updateAllUsers ); + + console.log("result", result); teamResults.push({ teamId, @@ -177,6 +195,7 @@ const BulkEditUserModal: React.FC = ({ setSelectedTeams([]); setTeamBudget(null); setAddToTeams(false); + setUpdateAllUsers(false); onSuccess(); onCancel(); @@ -193,11 +212,30 @@ const BulkEditUserModal: React.FC = ({ visible={visible} onCancel={handleCancel} footer={null} - title={`Bulk Edit ${selectedUsers.length} User(s)`} + title={updateAllUsers ? "Bulk Edit All Users" : `Bulk Edit ${selectedUsers.length} User(s)`} width={800} > -
- Selected Users ({selectedUsers.length}): + {allowAllUsers && ( +
+ setUpdateAllUsers(e.target.checked)} + > + Update ALL users in the system + + {updateAllUsers && ( +
+ + ⚠️ This will apply changes to ALL users in the system, not just the selected ones. + +
+ )} +
+ )} + + {!updateAllUsers && ( +
+ Selected Users ({selectedUsers.length}): = ({ }, ]} /> - + + )} @@ -334,7 +373,7 @@ const BulkEditUserModal: React.FC = ({ {loading && (
- Updating {selectedUsers.length} user(s)... + Updating {updateAllUsers ? "all users" : selectedUsers.length} user(s)...
)} diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index 9724b5ad4..2e45b0f9c 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -3700,8 +3700,9 @@ export const teamMemberAddCall = async ( export const teamBulkMemberAddCall = async ( accessToken: string, teamId: string, - members: Member[], - maxBudgetInTeam?: number + members: Member[] | null, + maxBudgetInTeam?: number, + allUsers?: boolean ) => { try { console.log("Bulk add team members:", { teamId, members, maxBudgetInTeam }); @@ -3710,11 +3711,16 @@ export const teamBulkMemberAddCall = async ( ? `${proxyBaseUrl}/team/bulk_member_add` : `/team/bulk_member_add`; - const requestBody: any = { + let requestBody: any = { team_id: teamId, - members: members, }; + if (allUsers) { + requestBody.all_users = true; + } else { + requestBody.members = members; + } + if (maxBudgetInTeam !== undefined && maxBudgetInTeam !== null) { requestBody.max_budget_in_team = maxBudgetInTeam; } @@ -4017,7 +4023,8 @@ export const userUpdateUserCall = async ( export const userBulkUpdateUserCall = async ( accessToken: string, formValues: any, // Assuming formValues is an object - userIds: string[] + userIds?: string[], // Optional - if not provided, will update all users + allUsers: boolean = false // Flag to update all users ) => { try { console.log("Form Values in userUpdateUserCall:", formValues); // Log the form values before making the API call @@ -4025,16 +4032,31 @@ export const userBulkUpdateUserCall = async ( const url = proxyBaseUrl ? `${proxyBaseUrl}/user/bulk_update` : `/user/bulk_update`; - let request_body = [] - for (const user_id of userIds) { - request_body.push({ - user_id: user_id, - ...formValues, + + let request_body_json: string; + + if (allUsers) { + // Update all users mode + request_body_json = JSON.stringify({ + all_users: true, + user_updates: formValues, }); + } else if (userIds && userIds.length > 0) { + // Update specific users mode + let request_body = [] + for (const user_id of userIds) { + request_body.push({ + user_id: user_id, + ...formValues, + }); + } + request_body_json = JSON.stringify({ + users: request_body, + }); + } else { + throw new Error("Must provide either userIds or set allUsers=true"); } - let request_body_json = JSON.stringify({ - users: request_body, - }); + const response = await fetch(url, { method: "POST", headers: { @@ -4052,8 +4074,16 @@ export const userBulkUpdateUserCall = async ( } const data = (await response.json()) as { - user_id: string; - data: UserInfo; + results: Array<{ + user_id?: string; + user_email?: string; + success: boolean; + error?: string; + updated_user?: any; + }>; + total_requested: number; + successful_updates: number; + failed_updates: number; }; console.log("API Response:", data); //message.success("User role updated"); diff --git a/ui/litellm-dashboard/src/components/user_edit_view.tsx b/ui/litellm-dashboard/src/components/user_edit_view.tsx index 431e6126d..c79d34309 100644 --- a/ui/litellm-dashboard/src/components/user_edit_view.tsx +++ b/ui/litellm-dashboard/src/components/user_edit_view.tsx @@ -132,6 +132,9 @@ export function UserEditView({ All Proxy Models + + No Default Models + {userModels.map((model) => ( {getModelDisplayName(model)} @@ -162,7 +165,7 @@ export function UserEditView({
-
) diff --git a/ui/litellm-dashboard/src/utils/roles.ts b/ui/litellm-dashboard/src/utils/roles.ts index 0bf928a6c..2038989ca 100644 --- a/ui/litellm-dashboard/src/utils/roles.ts +++ b/ui/litellm-dashboard/src/utils/roles.ts @@ -1,14 +1,13 @@ // Define admin roles and permissions -export const old_admin_roles = ["Admin", "Admin Viewer"]; -export const v2_admin_role_names = ["proxy_admin", "proxy_admin_viewer", "org_admin"]; -export const all_admin_roles = [...old_admin_roles, ...v2_admin_role_names]; +export const old_admin_roles = ["Admin", "Admin Viewer"] +export const v2_admin_role_names = ["proxy_admin", "proxy_admin_viewer", "org_admin"] +export const all_admin_roles = [...old_admin_roles, ...v2_admin_role_names] -export const internalUserRoles = ["Internal User", "Internal Viewer"]; -export const rolesAllowedToSeeUsage = ["Admin", "Admin Viewer", "Internal User", "Internal Viewer"]; -export const rolesWithWriteAccess = ["Internal User", "Admin"]; +export const internalUserRoles = ["Internal User", "Internal Viewer"] +export const rolesAllowedToSeeUsage = ["Admin", "Admin Viewer", "Internal User", "Internal Viewer"] +export const rolesWithWriteAccess = ["Internal User", "Admin"] // Helper function to check if a role is in all_admin_roles export const isAdminRole = (role: string): boolean => { - return all_admin_roles.includes(role); -}; - + return all_admin_roles.includes(role) +}