From dce6cd1051480f5c84f5f72385b4fb2ef5f806f0 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 9 Oct 2025 22:18:05 +0530 Subject: [PATCH] Add shared healthcheck --- docs/my-website/docs/proxy/health.md | 21 +- .../docs/proxy/shared_health_check.md | 310 ++++++++++++++ litellm/constants.py | 6 + litellm/proxy/health_check_utils/__init__.py | 1 + .../shared_health_check_manager.py | 329 ++++++++++++++ .../health_endpoints/_health_endpoints.py | 46 ++ litellm/proxy/proxy_server.py | 46 +- proxy_server_config.yaml | 8 + .../proxy/test_shared_health_check.py | 403 ++++++++++++++++++ 9 files changed, 1163 insertions(+), 7 deletions(-) create mode 100644 docs/my-website/docs/proxy/shared_health_check.md create mode 100644 litellm/proxy/health_check_utils/__init__.py create mode 100644 litellm/proxy/health_check_utils/shared_health_check_manager.py create mode 100644 tests/test_litellm/proxy/test_shared_health_check.py diff --git a/docs/my-website/docs/proxy/health.md b/docs/my-website/docs/proxy/health.md index 7e627846b..c96753648 100644 --- a/docs/my-website/docs/proxy/health.md +++ b/docs/my-website/docs/proxy/health.md @@ -9,13 +9,32 @@ Use this to health check all LLMs defined in your config.yaml | `/health/readiness` | **Load balancer health checks** | Ready to accept traffic - includes DB connection status | | `/health` | **Model health monitoring** | Comprehensive LLM model health - makes actual API calls | | `/health/services` | **Service debugging** | Check specific integrations (datadog, langfuse, etc.) 
| +| `/health/shared-status` | **Multi-pod coordination** | Monitor shared health check state across pods | ## Summary The proxy exposes: * a /health endpoint which returns the health of the LLM APIs * a /health/readiness endpoint for returning if the proxy is ready to accept requests -* a /health/liveliness endpoint for returning if the proxy is alive +* a /health/liveliness endpoint for returning if the proxy is alive +* a /health/shared-status endpoint for monitoring shared health check coordination across pods + +## Shared Health Check State + +When running multiple LiteLLM proxy pods, you can enable shared health check state to coordinate health checks across pods and avoid duplicate API calls. This is especially beneficial for expensive models like Gemini 2.5-pro. + +**Key Benefits:** +- Reduces duplicate health checks across pods +- Saves costs on expensive model API calls +- Reduces monitoring noise and logging +- Improves resource efficiency + +**Requirements:** +- Redis for shared state coordination +- Background health checks enabled +- Multiple proxy pods + +For detailed configuration and usage, see [Shared Health Check State](./shared_health_check.md). ## `/health` #### Request diff --git a/docs/my-website/docs/proxy/shared_health_check.md b/docs/my-website/docs/proxy/shared_health_check.md new file mode 100644 index 000000000..d4b701163 --- /dev/null +++ b/docs/my-website/docs/proxy/shared_health_check.md @@ -0,0 +1,310 @@ +# Shared Health Check State Across Pods + +This feature enables coordination of health checks across multiple LiteLLM proxy pods to avoid duplicate health checks and reduce costs. + +## Overview + +When running multiple LiteLLM proxy pods (e.g., in Kubernetes), each pod typically runs its own independent health checks on every model. 
This can result in: + +- **Duplicate health checks** across pods +- **Increased costs** for expensive models (e.g., Gemini 2.5-pro) +- **Redundant monitoring/logging noise** +- **Inefficient resource usage** + +The shared health check state feature solves this by: + +- **Coordinating health checks** across pods using Redis +- **Caching results** with configurable TTL +- **Using distributed locks** to ensure only one pod runs health checks at a time +- **Allowing other pods** to read cached results instead of running redundant checks + +## How It Works + +### 1. Lock Acquisition +When a pod needs to run health checks: +- It attempts to acquire a Redis lock +- If successful, it runs the health checks +- If failed, it waits briefly and checks for cached results + +### 2. Result Caching +After running health checks: +- Results are cached in Redis with a configurable TTL +- Other pods can read these cached results +- Cache includes timestamp and pod ID for tracking + +### 3. Fallback Behavior +If Redis is unavailable or cache is expired: +- Pods fall back to running health checks locally +- System continues to function normally + +## Configuration + +### Enable Shared Health Check + +Add to your `proxy_config.yaml`: + +```yaml +general_settings: + # Enable background health checks (required) + background_health_checks: true + + # Enable shared health check state across pods + use_shared_health_check: true + + # Health check interval (seconds) + health_check_interval: 300 # 5 minutes + +# Redis configuration (required for shared health check) +litellm_settings: + cache: true + cache_params: + type: redis + host: your-redis-host + port: 6379 + password: your-redis-password +``` + +### Environment Variables + +You can also configure using environment variables: + +```bash +# Enable shared health check +export USE_SHARED_HEALTH_CHECK=true + +# Health check TTL (seconds) +export DEFAULT_SHARED_HEALTH_CHECK_TTL=300 + +# Lock TTL (seconds) +export 
DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL=60 +``` + +## Requirements + +- **Redis**: Required for shared state coordination +- **Background Health Checks**: Must be enabled (`background_health_checks: true`) +- **Multiple Pods**: Most beneficial with 2+ proxy instances + +## API Endpoints + +### Check Shared Health Check Status + +```bash +GET /health/shared-status +``` + +Returns information about the shared health check coordination: + +```json +{ + "shared_health_check_enabled": true, + "status": { + "pod_id": "pod_1703123456789", + "redis_available": true, + "lock_ttl": 60, + "cache_ttl": 300, + "lock_owner": "pod_1703123456788", + "lock_in_progress": true, + "cache_available": true, + "cache_age_seconds": 45.2, + "last_checked_by": "pod_1703123456788" + } +} +``` + +## Monitoring + +### Health Check Status + +Monitor the shared health check status to ensure proper coordination: + +```bash +curl -H "Authorization: Bearer your-api-key" \ + http://your-proxy-host/health/shared-status +``` + +### Logs + +Look for these log messages: + +``` +INFO: Initialized shared health check manager +INFO: Pod pod_123 acquired health check lock +INFO: Pod pod_123 released health check lock +INFO: Cached health check results for 5 healthy and 0 unhealthy endpoints +DEBUG: Using cached health check results +``` + +## Troubleshooting + +### Common Issues + +#### 1. Shared Health Check Not Working + +**Symptoms**: Each pod still runs independent health checks + +**Solutions**: +- Verify Redis is configured and accessible +- Check that `use_shared_health_check: true` is set +- Ensure `background_health_checks: true` is enabled +- Check Redis connectivity in logs + +#### 2. Redis Connection Issues + +**Symptoms**: Health checks fall back to local execution + +**Solutions**: +- Verify Redis host, port, and credentials +- Check network connectivity between pods and Redis +- Monitor Redis server logs for errors + +#### 3. 
Lock Not Released + +**Symptoms**: One pod holds the lock indefinitely + +**Solutions**: +- Lock has automatic TTL (default 60 seconds) +- Check pod logs for lock release messages +- Verify Redis TTL settings + +### Debug Mode + +Enable debug logging to see detailed coordination: + +```yaml +general_settings: + set_verbose: true +``` + +## Performance Impact + +### Benefits + +- **Reduced API calls**: Only one pod runs health checks per interval +- **Lower costs**: Especially significant for expensive models +- **Better resource utilization**: Less redundant work across pods +- **Cleaner monitoring**: Reduced noise in logs and metrics + +### Overhead + +- **Redis operations**: Minimal overhead for lock/cache operations +- **Network latency**: Small delay for Redis communication +- **Memory usage**: Negligible additional memory usage + +## Best Practices + +### 1. Redis Configuration + +- Use Redis with persistence enabled +- Configure appropriate memory limits +- Set up Redis monitoring and alerts + +### 2. TTL Settings + +- Set `health_check_interval` to your desired check frequency +- Use default TTL values unless you have specific requirements +- Consider model-specific timeouts for expensive models + +### 3. Monitoring + +- Monitor shared health check status endpoint +- Set up alerts for Redis connectivity issues +- Track health check costs and frequency + +### 4. 
Scaling + +- Feature works with any number of pods +- More pods = better coordination benefits +- Consider Redis cluster for high availability + +## Example Configuration + +### Complete Example + +```yaml +# proxy_config.yaml +model_list: + - model_name: gpt-4 + litellm_params: + model: gpt-4 + api_key: os.environ/OPENAI_API_KEY + model_info: + health_check_timeout: 30 # 30 second timeout for health checks + +general_settings: + # Enable background health checks + background_health_checks: true + + # Enable shared health check coordination + use_shared_health_check: true + + # Health check interval (5 minutes) + health_check_interval: 300 + + # Health check details + health_check_details: true + +litellm_settings: + # Redis configuration + cache: true + cache_params: + type: redis + host: redis-cluster.example.com + port: 6379 + password: os.environ/REDIS_PASSWORD + ssl: true +``` + +### Kubernetes Example + +```yaml +# deployment.yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: litellm-proxy +spec: + replicas: 3 # Multiple pods for coordination + template: + spec: + containers: + - name: litellm-proxy + image: ghcr.io/berriai/litellm:latest + env: + - name: USE_SHARED_HEALTH_CHECK + value: "true" + - name: REDIS_HOST + value: "redis-service" + - name: REDIS_PASSWORD + valueFrom: + secretKeyRef: + name: redis-secret + key: password +``` + +## Migration + +### From Independent Health Checks + +1. **Enable Redis**: Ensure Redis is configured and accessible +2. **Enable Background Health Checks**: Set `background_health_checks: true` +3. **Enable Shared Health Check**: Set `use_shared_health_check: true` +4. **Deploy**: Update your proxy configuration +5. 
**Monitor**: Check `/health/shared-status` endpoint + +### Rollback + +To disable shared health check: + +```yaml +general_settings: + use_shared_health_check: false + # background_health_checks can remain true for independent checks +``` + +## Related Features + +- [Background Health Checks](./health.md#background-health-checks) +- [Redis Caching](./caching.md) +- [High Availability Setup](./db_deadlocks.md) +- [Health Check Endpoints](./health.md#health-endpoints) diff --git a/litellm/constants.py b/litellm/constants.py index 7fe0e69a4..697a6fb65 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -1028,6 +1028,12 @@ PROXY_BATCH_WRITE_AT = int(os.getenv("PROXY_BATCH_WRITE_AT", 10)) # in seconds DEFAULT_HEALTH_CHECK_INTERVAL = int( os.getenv("DEFAULT_HEALTH_CHECK_INTERVAL", 300) ) # 5 minutes +DEFAULT_SHARED_HEALTH_CHECK_TTL = int( + os.getenv("DEFAULT_SHARED_HEALTH_CHECK_TTL", 300) +) # 5 minutes - TTL for cached health check results +DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL = int( + os.getenv("DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL", 60) +) # 1 minute - TTL for health check lock PROMETHEUS_FALLBACK_STATS_SEND_TIME_HOURS = int( os.getenv("PROMETHEUS_FALLBACK_STATS_SEND_TIME_HOURS", 9) ) diff --git a/litellm/proxy/health_check_utils/__init__.py b/litellm/proxy/health_check_utils/__init__.py new file mode 100644 index 000000000..fe47bae6e --- /dev/null +++ b/litellm/proxy/health_check_utils/__init__.py @@ -0,0 +1 @@ +# Health check package diff --git a/litellm/proxy/health_check_utils/shared_health_check_manager.py b/litellm/proxy/health_check_utils/shared_health_check_manager.py new file mode 100644 index 000000000..98e0edb68 --- /dev/null +++ b/litellm/proxy/health_check_utils/shared_health_check_manager.py @@ -0,0 +1,329 @@ +import asyncio +import json +import time +from typing import Any, Dict, List, Optional, Tuple, Union + +from litellm._logging import verbose_proxy_logger +from litellm.caching.redis_cache import RedisCache +from litellm.constants 
import ( + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_SHARED_HEALTH_CHECK_TTL, + DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL, +) +from litellm.proxy.health_check import perform_health_check + + +class SharedHealthCheckManager: + """ + Manager for coordinating health checks across multiple pods using Redis. + + This class implements a shared health check state mechanism that: + - Prevents duplicate health checks across pods + - Caches health check results with configurable TTL + - Uses Redis locks to ensure only one pod runs health checks at a time + - Allows other pods to read cached results instead of running redundant checks + """ + + def __init__( + self, + redis_cache: Optional[RedisCache] = None, + health_check_ttl: int = DEFAULT_SHARED_HEALTH_CHECK_TTL, + lock_ttl: int = DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL, + ): + self.redis_cache = redis_cache + self.health_check_ttl = health_check_ttl + self.lock_ttl = lock_ttl + self.pod_id = f"pod_{int(time.time() * 1000)}" + + @staticmethod + def get_health_check_lock_key() -> str: + """Get the Redis key for health check lock.""" + return "health_check_lock" + + @staticmethod + def get_health_check_cache_key() -> str: + """Get the Redis key for health check results cache.""" + return "health_check_results" + + @staticmethod + def get_model_health_check_lock_key(model_name: str) -> str: + """Get the Redis key for model-specific health check lock.""" + return f"health_check_lock:{model_name}" + + @staticmethod + def get_model_health_check_cache_key(model_name: str) -> str: + """Get the Redis key for model-specific health check results cache.""" + return f"health_check_results:{model_name}" + + async def acquire_health_check_lock(self) -> bool: + """ + Attempt to acquire the global health check lock. 
+ + Returns: + bool: True if lock was acquired, False otherwise + """ + if self.redis_cache is None: + verbose_proxy_logger.debug("redis_cache is None, skipping lock acquisition") + return False + + try: + lock_key = self.get_health_check_lock_key() + acquired = await self.redis_cache.async_set_cache( + lock_key, + self.pod_id, + nx=True, # Only set if key doesn't exist + ttl=self.lock_ttl, + ) + + if acquired: + verbose_proxy_logger.info( + "Pod %s acquired health check lock", self.pod_id + ) + else: + verbose_proxy_logger.debug( + "Pod %s failed to acquire health check lock", self.pod_id + ) + + return acquired + except Exception as e: + verbose_proxy_logger.error( + "Error acquiring health check lock: %s", str(e) + ) + return False + + async def release_health_check_lock(self) -> None: + """Release the global health check lock.""" + if self.redis_cache is None: + return + + try: + lock_key = self.get_health_check_lock_key() + # Only release if we own the lock + current_owner = await self.redis_cache.async_get_cache(lock_key) + if current_owner == self.pod_id: + await self.redis_cache.async_delete_cache(lock_key) + verbose_proxy_logger.info( + "Pod %s released health check lock", self.pod_id + ) + except Exception as e: + verbose_proxy_logger.error( + "Error releasing health check lock: %s", str(e) + ) + + async def get_cached_health_check_results(self) -> Optional[Dict[str, Any]]: + """ + Get cached health check results from Redis. 
+ + Returns: + Optional[Dict]: Cached health check results or None if not found/expired + """ + if self.redis_cache is None: + return None + + try: + cache_key = self.get_health_check_cache_key() + cached_data = await self.redis_cache.async_get_cache(cache_key) + + if cached_data is None: + return None + + # Parse the cached data + if isinstance(cached_data, str): + cached_results = json.loads(cached_data) + else: + cached_results = cached_data + + # Check if the cache is still valid + cache_timestamp = cached_results.get("timestamp", 0) + current_time = time.time() + + if current_time - cache_timestamp > self.health_check_ttl: + verbose_proxy_logger.debug("Cached health check results expired") + return None + + verbose_proxy_logger.debug("Using cached health check results") + return cached_results + + except Exception as e: + verbose_proxy_logger.error( + "Error getting cached health check results: %s", str(e) + ) + return None + + async def cache_health_check_results( + self, + healthy_endpoints: List[Dict[str, Any]], + unhealthy_endpoints: List[Dict[str, Any]] + ) -> None: + """ + Cache health check results in Redis. 
+ + Args: + healthy_endpoints: List of healthy endpoints + unhealthy_endpoints: List of unhealthy endpoints + """ + if self.redis_cache is None: + return + + try: + cache_data = { + "healthy_endpoints": healthy_endpoints, + "unhealthy_endpoints": unhealthy_endpoints, + "healthy_count": len(healthy_endpoints), + "unhealthy_count": len(unhealthy_endpoints), + "timestamp": time.time(), + "checked_by": self.pod_id, + } + + cache_key = self.get_health_check_cache_key() + await self.redis_cache.async_set_cache( + cache_key, + json.dumps(cache_data), + ttl=self.health_check_ttl, + ) + + verbose_proxy_logger.info( + "Cached health check results for %d healthy and %d unhealthy endpoints", + len(healthy_endpoints), + len(unhealthy_endpoints), + ) + + except Exception as e: + verbose_proxy_logger.error( + "Error caching health check results: %s", str(e) + ) + + async def perform_shared_health_check( + self, + model_list: List[Dict[str, Any]], + details: bool = True + ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """ + Perform health check with shared state coordination. + + This method: + 1. First checks if there are recent cached results + 2. If no recent cache, tries to acquire lock to run health check + 3. If lock acquired, runs health check and caches results + 4. If lock not acquired, waits briefly and tries to get cached results again + 5. 
Falls back to running health check locally if no cache available + + Args: + model_list: List of models to check + details: Whether to include detailed information + + Returns: + Tuple of (healthy_endpoints, unhealthy_endpoints) + """ + # First, try to get cached results + cached_results = await self.get_cached_health_check_results() + if cached_results is not None: + return ( + cached_results.get("healthy_endpoints", []), + cached_results.get("unhealthy_endpoints", []), + ) + + # No recent cache, try to acquire lock + lock_acquired = await self.acquire_health_check_lock() + + if lock_acquired: + try: + # We have the lock, run health check + verbose_proxy_logger.info( + "Pod %s running health check for %d models", + self.pod_id, + len(model_list) + ) + + healthy_endpoints, unhealthy_endpoints = await perform_health_check( + model_list=model_list, details=details + ) + + # Cache the results + await self.cache_health_check_results( + healthy_endpoints, unhealthy_endpoints + ) + + return healthy_endpoints, unhealthy_endpoints + + finally: + # Always release the lock + await self.release_health_check_lock() + else: + # Lock not acquired, wait briefly and try to get cached results + verbose_proxy_logger.debug( + "Pod %s waiting for other pod to complete health check", self.pod_id + ) + + # Wait a bit for the other pod to complete + await asyncio.sleep(2) + + # Try to get cached results again + cached_results = await self.get_cached_health_check_results() + if cached_results is not None: + return ( + cached_results.get("healthy_endpoints", []), + cached_results.get("unhealthy_endpoints", []), + ) + + # Still no cache, fall back to local health check + verbose_proxy_logger.warning( + "Pod %s falling back to local health check (no cache available)", + self.pod_id + ) + + return await perform_health_check(model_list=model_list, details=details) + + async def is_health_check_in_progress(self) -> bool: + """ + Check if a health check is currently in progress by another pod. 
+ + Returns: + bool: True if health check is in progress, False otherwise + """ + if self.redis_cache is None: + return False + + try: + lock_key = self.get_health_check_lock_key() + current_owner = await self.redis_cache.async_get_cache(lock_key) + return current_owner is not None and current_owner != self.pod_id + except Exception as e: + verbose_proxy_logger.error( + "Error checking health check lock status: %s", str(e) + ) + return False + + async def get_health_check_status(self) -> Dict[str, Any]: + """ + Get the current status of health check coordination. + + Returns: + Dict containing status information + """ + status = { + "pod_id": self.pod_id, + "redis_available": self.redis_cache is not None, + "lock_ttl": self.lock_ttl, + "cache_ttl": self.health_check_ttl, + } + + if self.redis_cache is not None: + try: + # Check if there's a current lock + lock_key = self.get_health_check_lock_key() + current_owner = await self.redis_cache.async_get_cache(lock_key) + status["lock_owner"] = current_owner + status["lock_in_progress"] = current_owner is not None + + # Check cache status + cached_results = await self.get_cached_health_check_results() + status["cache_available"] = cached_results is not None + if cached_results: + status["cache_age_seconds"] = time.time() - cached_results.get("timestamp", 0) + status["last_checked_by"] = cached_results.get("checked_by") + + except Exception as e: + status["error"] = str(e) + + return status diff --git a/litellm/proxy/health_endpoints/_health_endpoints.py b/litellm/proxy/health_endpoints/_health_endpoints.py index 883bff318..6bb9d893a 100644 --- a/litellm/proxy/health_endpoints/_health_endpoints.py +++ b/litellm/proxy/health_endpoints/_health_endpoints.py @@ -608,6 +608,52 @@ async def latest_health_checks_endpoint( ) +@router.get( + "/health/shared-status", tags=["health"], dependencies=[Depends(user_api_key_auth)] +) +async def shared_health_check_status_endpoint( + user_api_key_dict: UserAPIKeyAuth = 
Depends(user_api_key_auth), +): + """ + Get the status of shared health check coordination across pods. + + Returns information about Redis connectivity, lock status, and cache status. + """ + from litellm.proxy.proxy_server import use_shared_health_check, redis_usage_cache + + if not use_shared_health_check: + return { + "shared_health_check_enabled": False, + "message": "Shared health check is not enabled" + } + + if redis_usage_cache is None: + return { + "shared_health_check_enabled": True, + "redis_available": False, + "message": "Redis is not configured" + } + + try: + from litellm.proxy.health_check_utils.shared_health_check_manager import SharedHealthCheckManager + + shared_health_manager = SharedHealthCheckManager( + redis_cache=redis_usage_cache, + ) + + coordination_status = await shared_health_manager.get_health_check_status()  # not named "status": a local of that name would shadow the fastapi status module used in the except block + return { + "shared_health_check_enabled": True, + "status": coordination_status + } + except Exception as e: + verbose_proxy_logger.error(f"Error getting shared health check status: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"error": f"Failed to retrieve shared health check status: {str(e)}"}, + ) + + db_health_cache = {"status": "unknown", "last_updated": datetime.now()} diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 366d3f57d..03cfe2e0d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -31,6 +31,8 @@ from litellm.constants import ( BASE_MCP_ROUTE, DEFAULT_MAX_RECURSE_DEPTH, DEFAULT_SLACK_ALERTING_THRESHOLD, + DEFAULT_SHARED_HEALTH_CHECK_TTL, + DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL, LITELLM_EMBEDDING_PROVIDERS_SUPPORTING_INPUT_ARRAY_OF_TOKENS, LITELLM_SETTINGS_SAFE_DB_OVERRIDES, ) @@ -510,7 +512,7 @@ _description = ( def cleanup_router_config_variables(): - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, user_custom_ui_sso_sign_in_handler, 
use_background_health_checks, health_check_interval, prisma_client + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, user_custom_ui_sso_sign_in_handler, use_background_health_checks, use_shared_health_check, health_check_interval, prisma_client # Set all variables to None master_key = None @@ -522,6 +524,7 @@ def cleanup_router_config_variables(): user_custom_sso = None user_custom_ui_sso_sign_in_handler = None use_background_health_checks = None + use_shared_health_check = None health_check_interval = None prisma_client = None @@ -970,6 +973,7 @@ user_custom_key_generate = None user_custom_sso = None user_custom_ui_sso_sign_in_handler = None use_background_health_checks = None +use_shared_health_check = None use_queue = False health_check_interval = None health_check_details = None @@ -1369,8 +1373,9 @@ async def _run_background_health_check(): Periodically run health checks in the background on the endpoints. Update health_check_results, based on this. + Uses shared health check state when Redis is available to coordinate across pods. 
""" - global health_check_results, llm_model_list, health_check_interval, health_check_details + global health_check_results, llm_model_list, health_check_interval, health_check_details, use_shared_health_check, redis_usage_cache if ( health_check_interval is None @@ -1379,6 +1384,17 @@ async def _run_background_health_check(): ): return + # Initialize shared health check manager if Redis is available and feature is enabled + shared_health_manager = None + if use_shared_health_check and redis_usage_cache is not None: + from litellm.proxy.health_check_utils.shared_health_check_manager import SharedHealthCheckManager + shared_health_manager = SharedHealthCheckManager( + redis_cache=redis_usage_cache, + health_check_ttl=DEFAULT_SHARED_HEALTH_CHECK_TTL, + lock_ttl=DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL, + ) + verbose_proxy_logger.info("Initialized shared health check manager") + while True: # make 1 deep copy of llm_model_list on every health check iteration _llm_model_list = copy.deepcopy(llm_model_list) or [] @@ -1390,9 +1406,23 @@ async def _run_background_health_check(): if not m.get("model_info", {}).get("disable_background_health_check", False) ] - healthy_endpoints, unhealthy_endpoints = await perform_health_check( - model_list=_llm_model_list, details=health_check_details - ) + # Use shared health check if available, otherwise fall back to direct health check + if shared_health_manager is not None: + try: + healthy_endpoints, unhealthy_endpoints = await shared_health_manager.perform_shared_health_check( + model_list=_llm_model_list, details=health_check_details + ) + except Exception as e: + verbose_proxy_logger.error( + "Error in shared health check, falling back to direct health check: %s", str(e) + ) + healthy_endpoints, unhealthy_endpoints = await perform_health_check( + model_list=_llm_model_list, details=health_check_details + ) + else: + healthy_endpoints, unhealthy_endpoints = await perform_health_check( + model_list=_llm_model_list, 
details=health_check_details + ) # Update the global variable with the health check results health_check_results["healthy_endpoints"] = healthy_endpoints @@ -1752,7 +1782,7 @@ class ProxyConfig: """ Load config values into proxy global state """ - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, user_custom_ui_sso_sign_in_handler, use_background_health_checks, health_check_interval, use_queue, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings, proxy_batch_polling_interval + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, user_custom_ui_sso_sign_in_handler, use_background_health_checks, use_shared_health_check, health_check_interval, use_queue, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings, proxy_batch_polling_interval config: dict = await self.get_config(config_file_path=config_file_path) @@ -2144,6 +2174,10 @@ class ProxyConfig: use_background_health_checks = general_settings.get( "background_health_checks", False ) + # Enable shared health check state across pods (requires Redis) + use_shared_health_check = general_settings.get( + "use_shared_health_check", False + ) health_check_interval = general_settings.get( "health_check_interval", DEFAULT_HEALTH_CHECK_INTERVAL ) diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index f78656811..8f125ac03 100644 --- 
a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -168,6 +168,11 @@ litellm_settings: langfuse_public_key: os.environ/LANGFUSE_PROJECT2_PUBLIC # Project 2 langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2 langfuse_host: https://us.cloud.langfuse.com + # cache: true # [OPTIONAL] use for caching responses + # cache_params: # And for shared health check + # type: redis + # host: localhost + # port: 6379 # For /fine_tuning/jobs endpoints finetune_settings: @@ -202,6 +207,9 @@ general_settings: proxy_budget_rescheduler_max_time: 64 proxy_batch_write_at: 1 database_connection_pool_limit: 10 + # background_health_checks: true + # use_shared_health_check: true + # health_check_interval: 30 # database_url: "postgresql://:@:/" # [OPTIONAL] use for token-based auth to proxy pass_through_endpoints: diff --git a/tests/test_litellm/proxy/test_shared_health_check.py b/tests/test_litellm/proxy/test_shared_health_check.py new file mode 100644 index 000000000..82deebc42 --- /dev/null +++ b/tests/test_litellm/proxy/test_shared_health_check.py @@ -0,0 +1,403 @@ +import asyncio +import json +import pytest +import time +from unittest.mock import AsyncMock, MagicMock, patch + +from litellm.proxy.health_check_utils.shared_health_check_manager import SharedHealthCheckManager + + +class TestSharedHealthCheckManager: + """Test cases for SharedHealthCheckManager""" + + @pytest.fixture + def mock_redis_cache(self): + """Mock Redis cache for testing""" + cache = AsyncMock() + cache.async_set_cache = AsyncMock() + cache.async_get_cache = AsyncMock() + cache.async_delete_cache = AsyncMock() + return cache + + @pytest.fixture + def shared_health_manager(self, mock_redis_cache): + """Create SharedHealthCheckManager instance for testing""" + return SharedHealthCheckManager( + redis_cache=mock_redis_cache, + health_check_ttl=300, + lock_ttl=60, + ) + + def test_initialization(self, mock_redis_cache): + """Test SharedHealthCheckManager initialization""" + manager = 
SharedHealthCheckManager( + redis_cache=mock_redis_cache, + health_check_ttl=300, + lock_ttl=60, + ) + + assert manager.redis_cache == mock_redis_cache + assert manager.health_check_ttl == 300 + assert manager.lock_ttl == 60 + assert manager.pod_id.startswith("pod_") + + def test_initialization_without_redis(self): + """Test SharedHealthCheckManager initialization without Redis""" + manager = SharedHealthCheckManager(redis_cache=None) + + assert manager.redis_cache is None + assert manager.health_check_ttl == 300 # Default value + assert manager.lock_ttl == 60 # Default value + + def test_get_health_check_lock_key(self): + """Test getting health check lock key""" + key = SharedHealthCheckManager.get_health_check_lock_key() + assert key == "health_check_lock" + + def test_get_health_check_cache_key(self): + """Test getting health check cache key""" + key = SharedHealthCheckManager.get_health_check_cache_key() + assert key == "health_check_results" + + def test_get_model_health_check_lock_key(self): + """Test getting model-specific health check lock key""" + key = SharedHealthCheckManager.get_model_health_check_lock_key("test-model") + assert key == "health_check_lock:test-model" + + def test_get_model_health_check_cache_key(self): + """Test getting model-specific health check cache key""" + key = SharedHealthCheckManager.get_model_health_check_cache_key("test-model") + assert key == "health_check_results:test-model" + + @pytest.mark.asyncio + async def test_acquire_health_check_lock_success(self, shared_health_manager, mock_redis_cache): + """Test successful lock acquisition""" + mock_redis_cache.async_set_cache.return_value = True + + result = await shared_health_manager.acquire_health_check_lock() + + assert result is True + mock_redis_cache.async_set_cache.assert_called_once_with( + "health_check_lock", + shared_health_manager.pod_id, + nx=True, + ttl=60, + ) + + @pytest.mark.asyncio + async def test_acquire_health_check_lock_failure(self, shared_health_manager, 
mock_redis_cache): + """Test failed lock acquisition""" + mock_redis_cache.async_set_cache.return_value = False + + result = await shared_health_manager.acquire_health_check_lock() + + assert result is False + + @pytest.mark.asyncio + async def test_acquire_health_check_lock_no_redis(self): + """Test lock acquisition without Redis""" + manager = SharedHealthCheckManager(redis_cache=None) + + result = await manager.acquire_health_check_lock() + + assert result is False + + @pytest.mark.asyncio + async def test_acquire_health_check_lock_exception(self, shared_health_manager, mock_redis_cache): + """Test lock acquisition with exception""" + mock_redis_cache.async_set_cache.side_effect = Exception("Redis error") + + result = await shared_health_manager.acquire_health_check_lock() + + assert result is False + + @pytest.mark.asyncio + async def test_release_health_check_lock_success(self, shared_health_manager, mock_redis_cache): + """Test successful lock release""" + mock_redis_cache.async_get_cache.return_value = shared_health_manager.pod_id + + await shared_health_manager.release_health_check_lock() + + mock_redis_cache.async_get_cache.assert_called_once_with("health_check_lock") + mock_redis_cache.async_delete_cache.assert_called_once_with("health_check_lock") + + @pytest.mark.asyncio + async def test_release_health_check_lock_wrong_owner(self, shared_health_manager, mock_redis_cache): + """Test lock release when not the owner""" + mock_redis_cache.async_get_cache.return_value = "other_pod_id" + + await shared_health_manager.release_health_check_lock() + + mock_redis_cache.async_get_cache.assert_called_once_with("health_check_lock") + mock_redis_cache.async_delete_cache.assert_not_called() + + @pytest.mark.asyncio + async def test_release_health_check_lock_no_redis(self): + """Test lock release without Redis""" + manager = SharedHealthCheckManager(redis_cache=None) + + # Should not raise exception + await manager.release_health_check_lock() + + @pytest.mark.asyncio + 
async def test_get_cached_health_check_results_success(self, shared_health_manager, mock_redis_cache): + """Test getting cached health check results successfully""" + current_time = time.time() + cached_data = { + "healthy_endpoints": [{"model": "test-model"}], + "unhealthy_endpoints": [], + "healthy_count": 1, + "unhealthy_count": 0, + "timestamp": current_time - 100, # 100 seconds ago + "checked_by": "test_pod", + } + mock_redis_cache.async_get_cache.return_value = json.dumps(cached_data) + + result = await shared_health_manager.get_cached_health_check_results() + + assert result is not None + assert result["healthy_count"] == 1 + assert result["unhealthy_count"] == 0 + + @pytest.mark.asyncio + async def test_get_cached_health_check_results_expired(self, shared_health_manager, mock_redis_cache): + """Test getting expired cached health check results""" + current_time = time.time() + cached_data = { + "healthy_endpoints": [{"model": "test-model"}], + "unhealthy_endpoints": [], + "healthy_count": 1, + "unhealthy_count": 0, + "timestamp": current_time - 400, # 400 seconds ago (expired) + "checked_by": "test_pod", + } + mock_redis_cache.async_get_cache.return_value = json.dumps(cached_data) + + result = await shared_health_manager.get_cached_health_check_results() + + assert result is None + + @pytest.mark.asyncio + async def test_get_cached_health_check_results_no_cache(self, shared_health_manager, mock_redis_cache): + """Test getting cached results when no cache exists""" + mock_redis_cache.async_get_cache.return_value = None + + result = await shared_health_manager.get_cached_health_check_results() + + assert result is None + + @pytest.mark.asyncio + async def test_get_cached_health_check_results_no_redis(self): + """Test getting cached results without Redis""" + manager = SharedHealthCheckManager(redis_cache=None) + + result = await manager.get_cached_health_check_results() + + assert result is None + + @pytest.mark.asyncio + async def 
test_cache_health_check_results_success(self, shared_health_manager, mock_redis_cache): + """Test caching health check results successfully""" + healthy_endpoints = [{"model": "test-model-1"}] + unhealthy_endpoints = [{"model": "test-model-2"}] + + await shared_health_manager.cache_health_check_results( + healthy_endpoints, unhealthy_endpoints + ) + + mock_redis_cache.async_set_cache.assert_called_once() + call_args = mock_redis_cache.async_set_cache.call_args + assert call_args[0][0] == "health_check_results" # key + assert call_args[1]["ttl"] == 300 # ttl + + # Verify cached data structure + cached_data = json.loads(call_args[0][1]) + assert cached_data["healthy_endpoints"] == healthy_endpoints + assert cached_data["unhealthy_endpoints"] == unhealthy_endpoints + assert cached_data["healthy_count"] == 1 + assert cached_data["unhealthy_count"] == 1 + assert "timestamp" in cached_data + assert cached_data["checked_by"] == shared_health_manager.pod_id + + @pytest.mark.asyncio + async def test_cache_health_check_results_no_redis(self): + """Test caching results without Redis""" + manager = SharedHealthCheckManager(redis_cache=None) + + # Should not raise exception + await manager.cache_health_check_results([], []) + + @pytest.mark.asyncio + async def test_perform_shared_health_check_with_cache(self, shared_health_manager, mock_redis_cache): + """Test performing shared health check when cache is available""" + # Mock cached results + cached_data = { + "healthy_endpoints": [{"model": "cached-model"}], + "unhealthy_endpoints": [], + "healthy_count": 1, + "unhealthy_count": 0, + "timestamp": time.time() - 100, + } + mock_redis_cache.async_get_cache.return_value = json.dumps(cached_data) + + model_list = [{"model_name": "test-model", "litellm_params": {"model": "test-model"}}] + + with patch("litellm.proxy.health_check_utils.shared_health_check_manager.perform_health_check") as mock_perform: + healthy, unhealthy = await shared_health_manager.perform_shared_health_check( + 
model_list, details=True + ) + + # Should return cached results, not call perform_health_check + assert healthy == [{"model": "cached-model"}] + assert unhealthy == [] + mock_perform.assert_not_called() + + @pytest.mark.asyncio + async def test_perform_shared_health_check_with_lock_acquisition(self, shared_health_manager, mock_redis_cache): + """Test performing shared health check when acquiring lock""" + # No cached results + mock_redis_cache.async_get_cache.return_value = None + # Lock acquisition succeeds + mock_redis_cache.async_set_cache.return_value = True + + model_list = [{"model_name": "test-model", "litellm_params": {"model": "test-model"}}] + expected_healthy = [{"model": "test-model", "status": "healthy"}] + expected_unhealthy = [] + + with patch("litellm.proxy.health_check_utils.shared_health_check_manager.perform_health_check") as mock_perform: + mock_perform.return_value = (expected_healthy, expected_unhealthy) + + healthy, unhealthy = await shared_health_manager.perform_shared_health_check( + model_list, details=True + ) + + # Should call perform_health_check and cache results + mock_perform.assert_called_once_with(model_list=model_list, details=True) + assert healthy == expected_healthy + assert unhealthy == expected_unhealthy + + # Should cache the results + assert mock_redis_cache.async_set_cache.call_count >= 2 # Lock + cache + + @pytest.mark.asyncio + async def test_perform_shared_health_check_lock_failed_then_cache(self, shared_health_manager, mock_redis_cache): + """Test performing shared health check when lock fails but cache becomes available""" + # First call: no cache, lock fails + # Second call: cache available + mock_redis_cache.async_get_cache.side_effect = [ + None, # No cache initially + json.dumps({ # Cache available after waiting + "healthy_endpoints": [{"model": "cached-model"}], + "unhealthy_endpoints": [], + "healthy_count": 1, + "unhealthy_count": 0, + "timestamp": time.time() - 100, + }) + ] + 
mock_redis_cache.async_set_cache.return_value = False # Lock acquisition fails + + model_list = [{"model_name": "test-model", "litellm_params": {"model": "test-model"}}] + + with patch("asyncio.sleep") as mock_sleep: # Mock sleep to avoid actual delay + healthy, unhealthy = await shared_health_manager.perform_shared_health_check( + model_list, details=True + ) + + # Should wait and then get cached results + mock_sleep.assert_called_once_with(2) + assert healthy == [{"model": "cached-model"}] + assert unhealthy == [] + + @pytest.mark.asyncio + async def test_perform_shared_health_check_fallback(self, shared_health_manager, mock_redis_cache): + """Test performing shared health check with fallback to local health check""" + # No cache, lock fails, no cache after waiting + mock_redis_cache.async_get_cache.return_value = None + mock_redis_cache.async_set_cache.return_value = False # Lock acquisition fails + + model_list = [{"model_name": "test-model", "litellm_params": {"model": "test-model"}}] + expected_healthy = [{"model": "test-model", "status": "healthy"}] + expected_unhealthy = [] + + with patch("asyncio.sleep") as mock_sleep, \ + patch("litellm.proxy.health_check_utils.shared_health_check_manager.perform_health_check") as mock_perform: + mock_perform.return_value = (expected_healthy, expected_unhealthy) + + healthy, unhealthy = await shared_health_manager.perform_shared_health_check( + model_list, details=True + ) + + # Should fall back to local health check + mock_sleep.assert_called_once_with(2) + mock_perform.assert_called_once_with(model_list=model_list, details=True) + assert healthy == expected_healthy + assert unhealthy == expected_unhealthy + + @pytest.mark.asyncio + async def test_is_health_check_in_progress_true(self, shared_health_manager, mock_redis_cache): + """Test checking if health check is in progress when it is""" + mock_redis_cache.async_get_cache.return_value = "other_pod_id" + + result = await 
shared_health_manager.is_health_check_in_progress() + + assert result is True + + @pytest.mark.asyncio + async def test_is_health_check_in_progress_false(self, shared_health_manager, mock_redis_cache): + """Test checking if health check is in progress when it's not""" + mock_redis_cache.async_get_cache.return_value = None + + result = await shared_health_manager.is_health_check_in_progress() + + assert result is False + + @pytest.mark.asyncio + async def test_is_health_check_in_progress_own_lock(self, shared_health_manager, mock_redis_cache): + """Test checking if health check is in progress when we own the lock""" + mock_redis_cache.async_get_cache.return_value = shared_health_manager.pod_id + + result = await shared_health_manager.is_health_check_in_progress() + + assert result is False + + @pytest.mark.asyncio + async def test_get_health_check_status(self, shared_health_manager, mock_redis_cache): + """Test getting health check status""" + current_time = time.time() + cached_data = { + "healthy_endpoints": [{"model": "test-model"}], + "unhealthy_endpoints": [], + "healthy_count": 1, + "unhealthy_count": 0, + "timestamp": current_time - 100, + "checked_by": "test_pod", + } + + mock_redis_cache.async_get_cache.side_effect = [ + "other_pod_id", # Lock owner + json.dumps(cached_data), # Cached results + ] + + status = await shared_health_manager.get_health_check_status() + + assert status["pod_id"] == shared_health_manager.pod_id + assert status["redis_available"] is True + assert status["lock_ttl"] == 60 + assert status["cache_ttl"] == 300 + assert status["lock_owner"] == "other_pod_id" + assert status["lock_in_progress"] is True + assert status["cache_available"] is True + assert status["last_checked_by"] == "test_pod" + assert "cache_age_seconds" in status + + @pytest.mark.asyncio + async def test_get_health_check_status_no_redis(self): + """Test getting health check status without Redis""" + manager = SharedHealthCheckManager(redis_cache=None) + + status = await 
manager.get_health_check_status() + + assert status["pod_id"] == manager.pod_id + assert status["redis_available"] is False + assert status["lock_ttl"] == 60 + assert status["cache_ttl"] == 300