mirror of
https://github.com/onyx-dot-app/litellm.git
synced 2026-07-01 20:44:04 -04:00
[Fix] Ensure guardrail memory sync after database updates (#15633)
* chore: Consistency in install-test-deps using poetry run * feat: update in-memory guardrails after database CRUD operations * test: add parameterized tests for guardrail CRUD with memory sync
This commit is contained in:
@@ -45,7 +45,7 @@ install-proxy-dev-ci:
|
||||
install-test-deps: install-proxy-dev
|
||||
poetry run pip install "pytest-retry==1.6.3"
|
||||
poetry run pip install pytest-xdist
|
||||
cd enterprise && python -m pip install -e . && cd ..
|
||||
cd enterprise && poetry run pip install -e . && cd ..
|
||||
|
||||
install-helm-unittest:
|
||||
helm plugin install https://github.com/helm-unittest/helm-unittest --version v0.4.4 || echo "ignore error if plugin exists"
|
||||
|
||||
@@ -251,6 +251,7 @@ async def create_guardrail(request: CreateGuardrailRequest):
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.guardrails.guardrail_registry import IN_MEMORY_GUARDRAIL_HANDLER
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
@@ -260,6 +261,22 @@ async def create_guardrail(request: CreateGuardrailRequest):
|
||||
result = await GUARDRAIL_REGISTRY.add_guardrail_to_db(
|
||||
guardrail=request.guardrail, prisma_client=prisma_client
|
||||
)
|
||||
|
||||
guardrail_name = result.get("guardrail_name", "Unknown")
|
||||
guardrail_id = result.get("guardrail_id", "Unknown")
|
||||
|
||||
try:
|
||||
IN_MEMORY_GUARDRAIL_HANDLER.initialize_guardrail(
|
||||
guardrail=cast(Guardrail, result)
|
||||
)
|
||||
verbose_proxy_logger.info(
|
||||
f"Immediate sync: Successfully initialized guardrail '{guardrail_name}' (ID: {guardrail_id})"
|
||||
)
|
||||
except Exception as init_error:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Immediate sync: Failed to initialize guardrail '{guardrail_name}' (ID: {guardrail_id}) in memory: {init_error}"
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error adding guardrail to db: {e}")
|
||||
@@ -323,6 +340,7 @@ async def update_guardrail(guardrail_id: str, request: UpdateGuardrailRequest):
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.guardrails.guardrail_registry import IN_MEMORY_GUARDRAIL_HANDLER
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
@@ -344,6 +362,21 @@ async def update_guardrail(guardrail_id: str, request: UpdateGuardrailRequest):
|
||||
guardrail=request.guardrail,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
guardrail_name = result.get("guardrail_name", "Unknown")
|
||||
|
||||
try:
|
||||
IN_MEMORY_GUARDRAIL_HANDLER.update_in_memory_guardrail(
|
||||
guardrail_id=guardrail_id, guardrail=cast(Guardrail, result)
|
||||
)
|
||||
verbose_proxy_logger.info(
|
||||
f"Immediate sync: Successfully updated guardrail '{guardrail_name}' (ID: {guardrail_id})"
|
||||
)
|
||||
except Exception as update_error:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Immediate sync: Failed to update '{guardrail_name}' (ID: {guardrail_id}) in memory: {update_error}"
|
||||
)
|
||||
|
||||
return result
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
@@ -396,10 +429,20 @@ async def delete_guardrail(guardrail_id: str):
|
||||
guardrail_id=guardrail_id, prisma_client=prisma_client
|
||||
)
|
||||
|
||||
# delete in memory guardrail
|
||||
IN_MEMORY_GUARDRAIL_HANDLER.delete_in_memory_guardrail(
|
||||
guardrail_id=guardrail_id,
|
||||
)
|
||||
guardrail_name = result.get("guardrail_name", "Unknown")
|
||||
|
||||
try:
|
||||
IN_MEMORY_GUARDRAIL_HANDLER.delete_in_memory_guardrail(
|
||||
guardrail_id=guardrail_id,
|
||||
)
|
||||
verbose_proxy_logger.info(
|
||||
f"Immediate sync: Successfully removed guardrail '{guardrail_name}' (ID: {guardrail_id}) from memory"
|
||||
)
|
||||
except Exception as delete_error:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Immediate sync: Failed to remove guardrail '{guardrail_name}' (ID: {guardrail_id}) from memory: {delete_error}"
|
||||
)
|
||||
|
||||
return result
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
@@ -514,11 +557,21 @@ async def patch_guardrail(guardrail_id: str, request: PatchGuardrailRequest):
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
# update in memory guardrail
|
||||
IN_MEMORY_GUARDRAIL_HANDLER.update_in_memory_guardrail(
|
||||
guardrail_id=guardrail_id,
|
||||
guardrail=guardrail,
|
||||
)
|
||||
guardrail_name = result.get("guardrail_name", "Unknown")
|
||||
|
||||
try:
|
||||
IN_MEMORY_GUARDRAIL_HANDLER.update_in_memory_guardrail(
|
||||
guardrail_id=guardrail_id,
|
||||
guardrail=guardrail,
|
||||
)
|
||||
verbose_proxy_logger.info(
|
||||
f"Immediate sync: Successfully updated guardrail '{guardrail_name}' (ID: {guardrail_id})"
|
||||
)
|
||||
except Exception as update_error:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Immediate sync: Failed to update '{guardrail_name}' (ID: {guardrail_id}) in memory: {update_error}"
|
||||
)
|
||||
|
||||
return result
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
@@ -760,7 +813,6 @@ def _extract_fields_recursive(
|
||||
model: Type[BaseModel],
|
||||
depth: int = 0,
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
# Check if we've exceeded the maximum recursion depth
|
||||
if depth > DEFAULT_MAX_RECURSE_DEPTH:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -16,6 +16,13 @@ from fastapi import HTTPException
|
||||
from litellm.proxy.guardrails.guardrail_endpoints import (
|
||||
get_guardrail_info,
|
||||
list_guardrails_v2,
|
||||
CreateGuardrailRequest,
|
||||
create_guardrail,
|
||||
UpdateGuardrailRequest,
|
||||
update_guardrail,
|
||||
PatchGuardrailRequest,
|
||||
patch_guardrail,
|
||||
delete_guardrail,
|
||||
)
|
||||
from litellm.proxy.guardrails.guardrail_registry import (
|
||||
IN_MEMORY_GUARDRAIL_HANDLER,
|
||||
@@ -25,6 +32,7 @@ from litellm.types.guardrails import (
|
||||
BaseLitellmParams,
|
||||
GuardrailInfoResponse,
|
||||
LitellmParams,
|
||||
Guardrail,
|
||||
)
|
||||
|
||||
# Mock data for testing
|
||||
@@ -50,6 +58,20 @@ MOCK_CONFIG_GUARDRAIL = {
|
||||
"guardrail_info": {"description": "Test guardrail from config"},
|
||||
}
|
||||
|
||||
MOCK_GUARDRAIL = Guardrail(
|
||||
guardrail_name=MOCK_CONFIG_GUARDRAIL["guardrail_name"],
|
||||
litellm_params=LitellmParams(**MOCK_CONFIG_GUARDRAIL["litellm_params"]),
|
||||
guardrail_info=MOCK_CONFIG_GUARDRAIL["guardrail_info"]
|
||||
)
|
||||
|
||||
MOCK_CREATE_REQUEST = CreateGuardrailRequest(guardrail=MOCK_GUARDRAIL)
|
||||
MOCK_UPDATE_REQUEST = UpdateGuardrailRequest(guardrail=MOCK_GUARDRAIL)
|
||||
MOCK_PATCH_REQUEST = PatchGuardrailRequest(
|
||||
guardrail_name="Updated Test Guardrail",
|
||||
litellm_params={"guardrail": "updated.guardrail", "mode": "post_call"},
|
||||
guardrail_info={"description": "Updated test guardrail"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prisma_client(mocker):
|
||||
@@ -73,8 +95,23 @@ def mock_in_memory_handler(mocker):
|
||||
mock_handler = mocker.Mock(spec=InMemoryGuardrailHandler)
|
||||
mock_handler.list_in_memory_guardrails.return_value = [MOCK_CONFIG_GUARDRAIL]
|
||||
mock_handler.get_guardrail_by_id.return_value = MOCK_CONFIG_GUARDRAIL
|
||||
mock_handler.initialize_guardrail = mocker.Mock()
|
||||
mock_handler.update_in_memory_guardrail = mocker.Mock()
|
||||
mock_handler.delete_in_memory_guardrail = mocker.Mock()
|
||||
return mock_handler
|
||||
|
||||
@pytest.fixture
|
||||
def mock_guardrail_registry(mocker):
|
||||
"""Mock GuardrailRegistry for testing"""
|
||||
mock_registry = mocker.Mock()
|
||||
mock_registry.add_guardrail_to_db = AsyncMock(return_value={
|
||||
**MOCK_DB_GUARDRAIL,
|
||||
"guardrail_id": "new-test-guardrail-id"
|
||||
})
|
||||
mock_registry.delete_guardrail_from_db = AsyncMock(return_value=MOCK_DB_GUARDRAIL)
|
||||
mock_registry.get_guardrail_by_id_from_db = AsyncMock(return_value=MOCK_DB_GUARDRAIL)
|
||||
mock_registry.update_guardrail_in_db = AsyncMock(return_value=MOCK_DB_GUARDRAIL)
|
||||
return mock_registry
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_guardrails_v2_with_db_and_config(
|
||||
@@ -477,4 +514,340 @@ async def test_bedrock_guardrail_make_api_request_passes_api_key():
|
||||
mock_aws_request.assert_called_once()
|
||||
call_args = mock_aws_request.call_args
|
||||
headers = call_args[1]["headers"]
|
||||
assert headers["Authorization"] == "Bearer test-api-key-789"
|
||||
assert headers["Authorization"] == "Bearer test-api-key-789"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("scenario,expected_result,expected_exception", [
|
||||
(
|
||||
"success_with_sync",
|
||||
"new-test-guardrail-id",
|
||||
None
|
||||
),
|
||||
(
|
||||
"success_sync_fails",
|
||||
"new-test-guardrail-id",
|
||||
None
|
||||
),
|
||||
(
|
||||
"database_failure",
|
||||
None,
|
||||
HTTPException
|
||||
),
|
||||
(
|
||||
"no_prisma_client",
|
||||
None,
|
||||
HTTPException
|
||||
),
|
||||
], ids=[
|
||||
"success_with_immediate_sync",
|
||||
"success_but_sync_fails",
|
||||
"database_error",
|
||||
"missing_prisma_client"
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_guardrail_endpoint(
|
||||
scenario, expected_result, expected_exception,
|
||||
mocker, mock_guardrail_registry, mock_in_memory_handler
|
||||
):
|
||||
"""Test create_guardrail endpoint with different scenarios"""
|
||||
|
||||
# Configure mocks based on scenario
|
||||
mock_logger = None
|
||||
if scenario == "success_with_sync":
|
||||
mock_prisma_client = mocker.Mock()
|
||||
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
|
||||
|
||||
elif scenario == "success_sync_fails":
|
||||
mock_prisma_client = mocker.Mock()
|
||||
mock_in_memory_handler.initialize_guardrail.side_effect = Exception("Sync failed")
|
||||
mock_logger = mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.verbose_proxy_logger")
|
||||
|
||||
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
|
||||
|
||||
elif scenario == "database_failure":
|
||||
mock_prisma_client = mocker.Mock()
|
||||
mock_guardrail_registry.add_guardrail_to_db.side_effect = Exception("Database error")
|
||||
|
||||
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
|
||||
|
||||
elif scenario == "no_prisma_client":
|
||||
mocker.patch("litellm.proxy.proxy_server.prisma_client", None)
|
||||
|
||||
# Run the test
|
||||
if expected_exception:
|
||||
with pytest.raises(expected_exception) as exc_info:
|
||||
await create_guardrail(MOCK_CREATE_REQUEST)
|
||||
|
||||
if scenario == "database_failure":
|
||||
assert "Database error" in str(exc_info.value.detail)
|
||||
elif scenario == "no_prisma_client":
|
||||
assert "Prisma client not initialized" in str(exc_info.value.detail)
|
||||
|
||||
else:
|
||||
result = await create_guardrail(MOCK_CREATE_REQUEST)
|
||||
|
||||
assert result["guardrail_id"] == expected_result
|
||||
assert result["guardrail_name"] == "Test DB Guardrail"
|
||||
|
||||
mock_guardrail_registry.add_guardrail_to_db.assert_called_once_with(
|
||||
guardrail=MOCK_CREATE_REQUEST.guardrail,
|
||||
prisma_client=mocker.ANY
|
||||
)
|
||||
|
||||
mock_in_memory_handler.initialize_guardrail.assert_called_once()
|
||||
|
||||
if scenario == "success_sync_fails":
|
||||
assert mock_logger is not None
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "Failed to initialize guardrail" in str(mock_logger.warning.call_args)
|
||||
|
||||
@pytest.mark.parametrize("scenario,expected_result,expected_exception", [
|
||||
(
|
||||
"success_with_sync",
|
||||
"test-db-guardrail",
|
||||
None
|
||||
),
|
||||
(
|
||||
"success_sync_fails",
|
||||
"test-db-guardrail",
|
||||
None
|
||||
),
|
||||
(
|
||||
"database_failure",
|
||||
None,
|
||||
HTTPException
|
||||
),
|
||||
(
|
||||
"no_prisma_client",
|
||||
None,
|
||||
HTTPException
|
||||
),
|
||||
], ids=[
|
||||
"success_with_immediate_sync",
|
||||
"success_but_sync_fails",
|
||||
"database_error",
|
||||
"missing_prisma_client"
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_guardrail_endpoint(
|
||||
scenario, expected_result, expected_exception,
|
||||
mocker, mock_guardrail_registry, mock_in_memory_handler
|
||||
):
|
||||
"""Test update_guardrail endpoint with different scenarios"""
|
||||
|
||||
# Configure mocks based on scenario
|
||||
mock_logger = None
|
||||
if scenario == "success_with_sync":
|
||||
mock_prisma_client = mocker.Mock()
|
||||
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
|
||||
|
||||
elif scenario == "success_sync_fails":
|
||||
mock_prisma_client = mocker.Mock()
|
||||
mock_in_memory_handler.update_in_memory_guardrail.side_effect = Exception("Sync failed")
|
||||
mock_logger = mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.verbose_proxy_logger")
|
||||
|
||||
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
|
||||
|
||||
elif scenario == "database_failure":
|
||||
mock_prisma_client = mocker.Mock()
|
||||
mock_guardrail_registry.update_guardrail_in_db.side_effect = Exception("Database error")
|
||||
|
||||
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
|
||||
|
||||
elif scenario == "no_prisma_client":
|
||||
mocker.patch("litellm.proxy.proxy_server.prisma_client", None)
|
||||
|
||||
# Run the test
|
||||
if expected_exception:
|
||||
with pytest.raises(expected_exception) as exc_info:
|
||||
await update_guardrail("test-guardrail-id", MOCK_UPDATE_REQUEST)
|
||||
|
||||
if scenario == "database_failure":
|
||||
assert "Database error" in str(exc_info.value.detail)
|
||||
elif scenario == "no_prisma_client":
|
||||
assert "Prisma client not initialized" in str(exc_info.value.detail)
|
||||
|
||||
else:
|
||||
result = await update_guardrail("test-guardrail-id", MOCK_UPDATE_REQUEST)
|
||||
|
||||
assert result["guardrail_id"] == expected_result
|
||||
assert result["guardrail_name"] == "Test DB Guardrail"
|
||||
|
||||
mock_guardrail_registry.update_guardrail_in_db.assert_called_once_with(
|
||||
guardrail_id="test-guardrail-id",
|
||||
guardrail=MOCK_UPDATE_REQUEST.guardrail,
|
||||
prisma_client=mocker.ANY
|
||||
)
|
||||
|
||||
mock_in_memory_handler.update_in_memory_guardrail.assert_called_once_with(
|
||||
guardrail_id="test-guardrail-id",
|
||||
guardrail=mocker.ANY
|
||||
)
|
||||
|
||||
if scenario == "success_sync_fails":
|
||||
assert mock_logger is not None
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "Failed to update" in str(mock_logger.warning.call_args)
|
||||
|
||||
@pytest.mark.parametrize("scenario,expected_result,expected_exception", [
|
||||
(
|
||||
"success_with_sync",
|
||||
"test-db-guardrail",
|
||||
None
|
||||
),
|
||||
(
|
||||
"success_sync_fails",
|
||||
"test-db-guardrail",
|
||||
None
|
||||
),
|
||||
(
|
||||
"database_failure",
|
||||
None,
|
||||
HTTPException
|
||||
),
|
||||
(
|
||||
"no_prisma_client",
|
||||
None,
|
||||
HTTPException
|
||||
),
|
||||
], ids=[
|
||||
"success_with_immediate_sync",
|
||||
"success_but_sync_fails",
|
||||
"database_error",
|
||||
"missing_prisma_client"
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_guardrail_endpoint(
|
||||
scenario, expected_result, expected_exception,
|
||||
mocker, mock_guardrail_registry, mock_in_memory_handler
|
||||
):
|
||||
"""Test patch_guardrail endpoint with different scenarios"""
|
||||
|
||||
# Configure mocks based on scenario
|
||||
mock_logger = None
|
||||
if scenario == "success_with_sync":
|
||||
mock_prisma_client = mocker.Mock()
|
||||
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
|
||||
|
||||
elif scenario == "success_sync_fails":
|
||||
mock_prisma_client = mocker.Mock()
|
||||
mock_in_memory_handler.update_in_memory_guardrail.side_effect = Exception("Sync failed")
|
||||
mock_logger = mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.verbose_proxy_logger")
|
||||
|
||||
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
|
||||
|
||||
elif scenario == "database_failure":
|
||||
mock_prisma_client = mocker.Mock()
|
||||
mock_guardrail_registry.update_guardrail_in_db.side_effect = Exception("Database error")
|
||||
|
||||
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
|
||||
|
||||
elif scenario == "no_prisma_client":
|
||||
mocker.patch("litellm.proxy.proxy_server.prisma_client", None)
|
||||
|
||||
# Run the test
|
||||
if expected_exception:
|
||||
with pytest.raises(expected_exception) as exc_info:
|
||||
await patch_guardrail("test-guardrail-id", MOCK_PATCH_REQUEST)
|
||||
|
||||
if scenario == "database_failure":
|
||||
assert "Database error" in str(exc_info.value.detail)
|
||||
elif scenario == "no_prisma_client":
|
||||
assert "Prisma client not initialized" in str(exc_info.value.detail)
|
||||
|
||||
else:
|
||||
result = await patch_guardrail("test-guardrail-id", MOCK_PATCH_REQUEST)
|
||||
|
||||
assert result["guardrail_id"] == expected_result
|
||||
assert result["guardrail_name"] == "Test DB Guardrail"
|
||||
|
||||
mock_guardrail_registry.update_guardrail_in_db.assert_called_once()
|
||||
|
||||
mock_in_memory_handler.update_in_memory_guardrail.assert_called_once_with(
|
||||
guardrail_id="test-guardrail-id",
|
||||
guardrail=mocker.ANY
|
||||
)
|
||||
|
||||
if scenario == "success_sync_fails":
|
||||
assert mock_logger is not None
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "Failed to update" in str(mock_logger.warning.call_args)
|
||||
|
||||
@pytest.mark.parametrize("scenario,expected_result,expected_exception", [
|
||||
(
|
||||
"success_with_sync",
|
||||
"test-db-guardrail",
|
||||
None
|
||||
),
|
||||
(
|
||||
"success_sync_fails",
|
||||
"test-db-guardrail",
|
||||
None
|
||||
),
|
||||
], ids=[
|
||||
"success_with_immediate_sync",
|
||||
"success_but_sync_fails"
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_guardrail_endpoint(
|
||||
scenario, expected_result, expected_exception,
|
||||
mocker, mock_guardrail_registry, mock_in_memory_handler
|
||||
):
|
||||
"""Test delete_guardrail endpoint with different scenarios"""
|
||||
|
||||
# Configure mocks based on scenario
|
||||
mock_prisma_client = mocker.Mock()
|
||||
mock_logger = None
|
||||
|
||||
if scenario == "success_with_sync":
|
||||
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
|
||||
|
||||
elif scenario == "success_sync_fails":
|
||||
mock_in_memory_handler.delete_in_memory_guardrail.side_effect = Exception("Sync failed")
|
||||
mock_logger = mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.verbose_proxy_logger")
|
||||
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_guardrail_registry)
|
||||
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)
|
||||
|
||||
if expected_exception:
|
||||
with pytest.raises(expected_exception):
|
||||
await delete_guardrail(guardrail_id=expected_result)
|
||||
else:
|
||||
result = await delete_guardrail(guardrail_id=expected_result)
|
||||
|
||||
assert result == MOCK_DB_GUARDRAIL
|
||||
|
||||
mock_guardrail_registry.get_guardrail_by_id_from_db.assert_called_once_with(
|
||||
guardrail_id=expected_result,
|
||||
prisma_client=mock_prisma_client
|
||||
)
|
||||
mock_guardrail_registry.delete_guardrail_from_db.assert_called_once_with(
|
||||
guardrail_id=expected_result,
|
||||
prisma_client=mock_prisma_client
|
||||
)
|
||||
|
||||
mock_in_memory_handler.delete_in_memory_guardrail.assert_called_once_with(
|
||||
guardrail_id=expected_result
|
||||
)
|
||||
|
||||
if scenario == "success_sync_fails":
|
||||
assert mock_logger is not None
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "Failed to remove guardrail" in str(mock_logger.warning.call_args)
|
||||
Reference in New Issue
Block a user