[Fix] Ensure guardrail memory sync after database updates (#15633)

* chore: Consistency in install-test-deps using poetry run

* feat: update in-memory guardrails after database CRUD operations

* test: add parameterized tests for guardrail CRUD with memory sync
This commit is contained in:
Nicholas Couture
2025-10-17 15:46:49 +11:00
committed by GitHub
parent 814aeb5fad
commit 8032e73872
3 changed files with 437 additions and 12 deletions
+1 -1
View File
@@ -45,7 +45,7 @@ install-proxy-dev-ci:
install-test-deps: install-proxy-dev
poetry run pip install "pytest-retry==1.6.3"
poetry run pip install pytest-xdist
cd enterprise && python -m pip install -e . && cd ..
cd enterprise && poetry run pip install -e . && cd ..
install-helm-unittest:
helm plugin install https://github.com/helm-unittest/helm-unittest --version v0.4.4 || echo "ignore error if plugin exists"
+62 -10
View File
@@ -251,6 +251,7 @@ async def create_guardrail(request: CreateGuardrailRequest):
}
```
"""
from litellm.proxy.guardrails.guardrail_registry import IN_MEMORY_GUARDRAIL_HANDLER
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
@@ -260,6 +261,22 @@ async def create_guardrail(request: CreateGuardrailRequest):
result = await GUARDRAIL_REGISTRY.add_guardrail_to_db(
guardrail=request.guardrail, prisma_client=prisma_client
)
guardrail_name = result.get("guardrail_name", "Unknown")
guardrail_id = result.get("guardrail_id", "Unknown")
try:
IN_MEMORY_GUARDRAIL_HANDLER.initialize_guardrail(
guardrail=cast(Guardrail, result)
)
verbose_proxy_logger.info(
f"Immediate sync: Successfully initialized guardrail '{guardrail_name}' (ID: {guardrail_id})"
)
except Exception as init_error:
verbose_proxy_logger.warning(
f"Immediate sync: Failed to initialize guardrail '{guardrail_name}' (ID: {guardrail_id}) in memory: {init_error}"
)
return result
except Exception as e:
verbose_proxy_logger.exception(f"Error adding guardrail to db: {e}")
@@ -323,6 +340,7 @@ async def update_guardrail(guardrail_id: str, request: UpdateGuardrailRequest):
}
```
"""
from litellm.proxy.guardrails.guardrail_registry import IN_MEMORY_GUARDRAIL_HANDLER
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
@@ -344,6 +362,21 @@ async def update_guardrail(guardrail_id: str, request: UpdateGuardrailRequest):
guardrail=request.guardrail,
prisma_client=prisma_client,
)
guardrail_name = result.get("guardrail_name", "Unknown")
try:
IN_MEMORY_GUARDRAIL_HANDLER.update_in_memory_guardrail(
guardrail_id=guardrail_id, guardrail=cast(Guardrail, result)
)
verbose_proxy_logger.info(
f"Immediate sync: Successfully updated guardrail '{guardrail_name}' (ID: {guardrail_id})"
)
except Exception as update_error:
verbose_proxy_logger.warning(
f"Immediate sync: Failed to update '{guardrail_name}' (ID: {guardrail_id}) in memory: {update_error}"
)
return result
except HTTPException as e:
raise e
@@ -396,10 +429,20 @@ async def delete_guardrail(guardrail_id: str):
guardrail_id=guardrail_id, prisma_client=prisma_client
)
# delete in memory guardrail
IN_MEMORY_GUARDRAIL_HANDLER.delete_in_memory_guardrail(
guardrail_id=guardrail_id,
)
guardrail_name = result.get("guardrail_name", "Unknown")
try:
IN_MEMORY_GUARDRAIL_HANDLER.delete_in_memory_guardrail(
guardrail_id=guardrail_id,
)
verbose_proxy_logger.info(
f"Immediate sync: Successfully removed guardrail '{guardrail_name}' (ID: {guardrail_id}) from memory"
)
except Exception as delete_error:
verbose_proxy_logger.warning(
f"Immediate sync: Failed to remove guardrail '{guardrail_name}' (ID: {guardrail_id}) from memory: {delete_error}"
)
return result
except HTTPException as e:
raise e
@@ -514,11 +557,21 @@ async def patch_guardrail(guardrail_id: str, request: PatchGuardrailRequest):
prisma_client=prisma_client,
)
# update in memory guardrail
IN_MEMORY_GUARDRAIL_HANDLER.update_in_memory_guardrail(
guardrail_id=guardrail_id,
guardrail=guardrail,
)
guardrail_name = result.get("guardrail_name", "Unknown")
try:
IN_MEMORY_GUARDRAIL_HANDLER.update_in_memory_guardrail(
guardrail_id=guardrail_id,
guardrail=guardrail,
)
verbose_proxy_logger.info(
f"Immediate sync: Successfully updated guardrail '{guardrail_name}' (ID: {guardrail_id})"
)
except Exception as update_error:
verbose_proxy_logger.warning(
f"Immediate sync: Failed to update '{guardrail_name}' (ID: {guardrail_id}) in memory: {update_error}"
)
return result
except HTTPException as e:
raise e
@@ -760,7 +813,6 @@ def _extract_fields_recursive(
model: Type[BaseModel],
depth: int = 0,
) -> Dict[str, Any]:
# Check if we've exceeded the maximum recursion depth
if depth > DEFAULT_MAX_RECURSE_DEPTH:
raise HTTPException(
@@ -16,6 +16,13 @@ from fastapi import HTTPException
from litellm.proxy.guardrails.guardrail_endpoints import (
get_guardrail_info,
list_guardrails_v2,
CreateGuardrailRequest,
create_guardrail,
UpdateGuardrailRequest,
update_guardrail,
PatchGuardrailRequest,
patch_guardrail,
delete_guardrail,
)
from litellm.proxy.guardrails.guardrail_registry import (
IN_MEMORY_GUARDRAIL_HANDLER,
@@ -25,6 +32,7 @@ from litellm.types.guardrails import (
BaseLitellmParams,
GuardrailInfoResponse,
LitellmParams,
Guardrail,
)
# Mock data for testing
@@ -50,6 +58,20 @@ MOCK_CONFIG_GUARDRAIL = {
"guardrail_info": {"description": "Test guardrail from config"},
}
MOCK_GUARDRAIL = Guardrail(
guardrail_name=MOCK_CONFIG_GUARDRAIL["guardrail_name"],
litellm_params=LitellmParams(**MOCK_CONFIG_GUARDRAIL["litellm_params"]),
guardrail_info=MOCK_CONFIG_GUARDRAIL["guardrail_info"]
)
MOCK_CREATE_REQUEST = CreateGuardrailRequest(guardrail=MOCK_GUARDRAIL)
MOCK_UPDATE_REQUEST = UpdateGuardrailRequest(guardrail=MOCK_GUARDRAIL)
MOCK_PATCH_REQUEST = PatchGuardrailRequest(
guardrail_name="Updated Test Guardrail",
litellm_params={"guardrail": "updated.guardrail", "mode": "post_call"},
guardrail_info={"description": "Updated test guardrail"}
)
@pytest.fixture
def mock_prisma_client(mocker):
@@ -73,8 +95,23 @@ def mock_in_memory_handler(mocker):
mock_handler = mocker.Mock(spec=InMemoryGuardrailHandler)
mock_handler.list_in_memory_guardrails.return_value = [MOCK_CONFIG_GUARDRAIL]
mock_handler.get_guardrail_by_id.return_value = MOCK_CONFIG_GUARDRAIL
mock_handler.initialize_guardrail = mocker.Mock()
mock_handler.update_in_memory_guardrail = mocker.Mock()
mock_handler.delete_in_memory_guardrail = mocker.Mock()
return mock_handler
@pytest.fixture
def mock_guardrail_registry(mocker):
"""Mock GuardrailRegistry for testing"""
mock_registry = mocker.Mock()
mock_registry.add_guardrail_to_db = AsyncMock(return_value={
**MOCK_DB_GUARDRAIL,
"guardrail_id": "new-test-guardrail-id"
})
mock_registry.delete_guardrail_from_db = AsyncMock(return_value=MOCK_DB_GUARDRAIL)
mock_registry.get_guardrail_by_id_from_db = AsyncMock(return_value=MOCK_DB_GUARDRAIL)
mock_registry.update_guardrail_in_db = AsyncMock(return_value=MOCK_DB_GUARDRAIL)
return mock_registry
@pytest.mark.asyncio
async def test_list_guardrails_v2_with_db_and_config(
@@ -477,4 +514,340 @@ async def test_bedrock_guardrail_make_api_request_passes_api_key():
mock_aws_request.assert_called_once()
call_args = mock_aws_request.call_args
headers = call_args[1]["headers"]
assert headers["Authorization"] == "Bearer test-api-key-789"
assert headers["Authorization"] == "Bearer test-api-key-789"
@pytest.mark.parametrize("scenario,expected_result,expected_exception", [
(
"success_with_sync",
"new-test-guardrail-id",
None
),
(
"success_sync_fails",
"new-test-guardrail-id",
None
),
(
"database_failure",
None,
HTTPException
),
(
"no_prisma_client",
None,
HTTPException
),
], ids=[
"success_with_immediate_sync",
"success_but_sync_fails",
"database_error",
"missing_prisma_client"
])
@pytest.mark.asyncio
async def test_create_guardrail_endpoint(
scenario, expected_result, expected_exception,
mocker, mock_guardrail_registry, mock_in_memory_handler
):
"""Test create_guardrail endpoint with different scenarios"""
# Configure mocks based on scenario
mock_logger = None
if scenario == "success_with_sync":
mock_prisma_client = mocker.Mock()
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
elif scenario == "success_sync_fails":
mock_prisma_client = mocker.Mock()
mock_in_memory_handler.initialize_guardrail.side_effect = Exception("Sync failed")
mock_logger = mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.verbose_proxy_logger")
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
elif scenario == "database_failure":
mock_prisma_client = mocker.Mock()
mock_guardrail_registry.add_guardrail_to_db.side_effect = Exception("Database error")
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
elif scenario == "no_prisma_client":
mocker.patch("litellm.proxy.proxy_server.prisma_client", None)
# Run the test
if expected_exception:
with pytest.raises(expected_exception) as exc_info:
await create_guardrail(MOCK_CREATE_REQUEST)
if scenario == "database_failure":
assert "Database error" in str(exc_info.value.detail)
elif scenario == "no_prisma_client":
assert "Prisma client not initialized" in str(exc_info.value.detail)
else:
result = await create_guardrail(MOCK_CREATE_REQUEST)
assert result["guardrail_id"] == expected_result
assert result["guardrail_name"] == "Test DB Guardrail"
mock_guardrail_registry.add_guardrail_to_db.assert_called_once_with(
guardrail=MOCK_CREATE_REQUEST.guardrail,
prisma_client=mocker.ANY
)
mock_in_memory_handler.initialize_guardrail.assert_called_once()
if scenario == "success_sync_fails":
assert mock_logger is not None
mock_logger.warning.assert_called_once()
assert "Failed to initialize guardrail" in str(mock_logger.warning.call_args)
@pytest.mark.parametrize("scenario,expected_result,expected_exception", [
(
"success_with_sync",
"test-db-guardrail",
None
),
(
"success_sync_fails",
"test-db-guardrail",
None
),
(
"database_failure",
None,
HTTPException
),
(
"no_prisma_client",
None,
HTTPException
),
], ids=[
"success_with_immediate_sync",
"success_but_sync_fails",
"database_error",
"missing_prisma_client"
])
@pytest.mark.asyncio
async def test_update_guardrail_endpoint(
scenario, expected_result, expected_exception,
mocker, mock_guardrail_registry, mock_in_memory_handler
):
"""Test update_guardrail endpoint with different scenarios"""
# Configure mocks based on scenario
mock_logger = None
if scenario == "success_with_sync":
mock_prisma_client = mocker.Mock()
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
elif scenario == "success_sync_fails":
mock_prisma_client = mocker.Mock()
mock_in_memory_handler.update_in_memory_guardrail.side_effect = Exception("Sync failed")
mock_logger = mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.verbose_proxy_logger")
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
elif scenario == "database_failure":
mock_prisma_client = mocker.Mock()
mock_guardrail_registry.update_guardrail_in_db.side_effect = Exception("Database error")
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
elif scenario == "no_prisma_client":
mocker.patch("litellm.proxy.proxy_server.prisma_client", None)
# Run the test
if expected_exception:
with pytest.raises(expected_exception) as exc_info:
await update_guardrail("test-guardrail-id", MOCK_UPDATE_REQUEST)
if scenario == "database_failure":
assert "Database error" in str(exc_info.value.detail)
elif scenario == "no_prisma_client":
assert "Prisma client not initialized" in str(exc_info.value.detail)
else:
result = await update_guardrail("test-guardrail-id", MOCK_UPDATE_REQUEST)
assert result["guardrail_id"] == expected_result
assert result["guardrail_name"] == "Test DB Guardrail"
mock_guardrail_registry.update_guardrail_in_db.assert_called_once_with(
guardrail_id="test-guardrail-id",
guardrail=MOCK_UPDATE_REQUEST.guardrail,
prisma_client=mocker.ANY
)
mock_in_memory_handler.update_in_memory_guardrail.assert_called_once_with(
guardrail_id="test-guardrail-id",
guardrail=mocker.ANY
)
if scenario == "success_sync_fails":
assert mock_logger is not None
mock_logger.warning.assert_called_once()
assert "Failed to update" in str(mock_logger.warning.call_args)
@pytest.mark.parametrize("scenario,expected_result,expected_exception", [
(
"success_with_sync",
"test-db-guardrail",
None
),
(
"success_sync_fails",
"test-db-guardrail",
None
),
(
"database_failure",
None,
HTTPException
),
(
"no_prisma_client",
None,
HTTPException
),
], ids=[
"success_with_immediate_sync",
"success_but_sync_fails",
"database_error",
"missing_prisma_client"
])
@pytest.mark.asyncio
async def test_patch_guardrail_endpoint(
scenario, expected_result, expected_exception,
mocker, mock_guardrail_registry, mock_in_memory_handler
):
"""Test patch_guardrail endpoint with different scenarios"""
# Configure mocks based on scenario
mock_logger = None
if scenario == "success_with_sync":
mock_prisma_client = mocker.Mock()
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
elif scenario == "success_sync_fails":
mock_prisma_client = mocker.Mock()
mock_in_memory_handler.update_in_memory_guardrail.side_effect = Exception("Sync failed")
mock_logger = mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.verbose_proxy_logger")
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
elif scenario == "database_failure":
mock_prisma_client = mocker.Mock()
mock_guardrail_registry.update_guardrail_in_db.side_effect = Exception("Database error")
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
elif scenario == "no_prisma_client":
mocker.patch("litellm.proxy.proxy_server.prisma_client", None)
# Run the test
if expected_exception:
with pytest.raises(expected_exception) as exc_info:
await patch_guardrail("test-guardrail-id", MOCK_PATCH_REQUEST)
if scenario == "database_failure":
assert "Database error" in str(exc_info.value.detail)
elif scenario == "no_prisma_client":
assert "Prisma client not initialized" in str(exc_info.value.detail)
else:
result = await patch_guardrail("test-guardrail-id", MOCK_PATCH_REQUEST)
assert result["guardrail_id"] == expected_result
assert result["guardrail_name"] == "Test DB Guardrail"
mock_guardrail_registry.update_guardrail_in_db.assert_called_once()
mock_in_memory_handler.update_in_memory_guardrail.assert_called_once_with(
guardrail_id="test-guardrail-id",
guardrail=mocker.ANY
)
if scenario == "success_sync_fails":
assert mock_logger is not None
mock_logger.warning.assert_called_once()
assert "Failed to update" in str(mock_logger.warning.call_args)
@pytest.mark.parametrize("scenario,expected_result,expected_exception", [
(
"success_with_sync",
"test-db-guardrail",
None
),
(
"success_sync_fails",
"test-db-guardrail",
None
),
], ids=[
"success_with_immediate_sync",
"success_but_sync_fails"
])
@pytest.mark.asyncio
async def test_delete_guardrail_endpoint(
scenario, expected_result, expected_exception,
mocker, mock_guardrail_registry, mock_in_memory_handler
):
"""Test delete_guardrail endpoint with different scenarios"""
# Configure mocks based on scenario
mock_prisma_client = mocker.Mock()
mock_logger = None
if scenario == "success_with_sync":
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
elif scenario == "success_sync_fails":
mock_in_memory_handler.delete_in_memory_guardrail.side_effect = Exception("Sync failed")
mock_logger = mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.verbose_proxy_logger")
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
if expected_exception:
with pytest.raises(expected_exception):
await delete_guardrail(guardrail_id=expected_result)
else:
result = await delete_guardrail(guardrail_id=expected_result)
assert result == MOCK_DB_GUARDRAIL
mock_guardrail_registry.get_guardrail_by_id_from_db.assert_called_once_with(
guardrail_id=expected_result,
prisma_client=mock_prisma_client
)
mock_guardrail_registry.delete_guardrail_from_db.assert_called_once_with(
guardrail_id=expected_result,
prisma_client=mock_prisma_client
)
mock_in_memory_handler.delete_in_memory_guardrail.assert_called_once_with(
guardrail_id=expected_result
)
if scenario == "success_sync_fails":
assert mock_logger is not None
mock_logger.warning.assert_called_once()
assert "Failed to remove guardrail" in str(mock_logger.warning.call_args)