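"""
Tests for the LiteLLM Bedrock vector store (knowledge base) integration.

Covers direct knowledge base retrieval, injection of retrieved context into
completion requests (via `vector_store_ids` and OpenAI-style `file_search`
tools), and logging of vector store requests in the standard logging payload.
"""
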
import asyncio
import json
import os
import sys

sys.path.insert(0, os.path.abspath("../.."))

from typing import List, Optional
from unittest.mock import Mock, patch

import pytest

import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.vector_stores.bedrock_vector_store import BedrockVectorStore
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.utils import (
    StandardLoggingPayload,
    StandardLoggingVectorStoreRequest,
)
from litellm.types.vector_stores import VectorStoreSearchResponse


class TestCustomLogger(CustomLogger):
    """Test logger that captures the standard logging payload for assertions."""

    def __init__(self):
        self.standard_logging_payload: Optional[StandardLoggingPayload] = None
        super().__init__()

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.standard_logging_payload = kwargs.get("standard_logging_object")


@pytest.fixture
def add_aws_region_to_env(monkeypatch):
    """Set AWS_REGION so Bedrock clients can resolve a region."""
    monkeypatch.setenv("AWS_REGION", "us-west-2")


@pytest.fixture
def setup_vector_store_registry():
    """Initialize the global vector store registry with a single Bedrock knowledge base."""
    from litellm.vector_stores.vector_store_registry import (
        LiteLLM_ManagedVectorStore,
        VectorStoreRegistry,
    )

    litellm.vector_store_registry = VectorStoreRegistry(
        vector_stores=[
            LiteLLM_ManagedVectorStore(
                vector_store_id="T37J8R4WTM",
                custom_llm_provider="bedrock",
            )
        ]
    )
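

# NOTE: litellm.vector_store_registry is module-level global state. Several tests
# below overwrite it (some set it to None), so each test requests this fixture to
# get a fresh registry rather than relying on leftover state from a prior test.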


@pytest.mark.asyncio
async def test_basic_bedrock_knowledgebase_retrieval(setup_vector_store_registry):
    bedrock_knowledgebase_hook = BedrockVectorStore(aws_region_name="us-west-2")
    response = await bedrock_knowledgebase_hook.make_bedrock_kb_retrieve_request(
        knowledge_base_id="T37J8R4WTM",
        query="what is litellm?",
    )
    assert response is not None


@pytest.mark.asyncio
async def test_e2e_bedrock_knowledgebase_retrieval_with_completion(setup_vector_store_registry):
    litellm._turn_on_debug()
    client = AsyncHTTPHandler()
    print("value of litellm.vector_store_registry:", litellm.vector_store_registry)
    with patch.object(client, "post") as mock_post:
        # Mock the response for the LLM call. The mock has no real `text`
        # payload, so response parsing is expected to fail; the try/except
        # lets the test still inspect the outbound request.
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.json = lambda: json.loads(mock_response.text)
        mock_post.return_value = mock_response

        try:
            await litellm.acompletion(
                model="anthropic/claude-3.5-sonnet",
                messages=[{"role": "user", "content": "what is litellm?"}],
                vector_store_ids=["T37J8R4WTM"],
                client=client,
            )
        except Exception as e:
            print(f"Error: {e}")

        # Verify the LLM request was made
        mock_post.assert_called_once()

        # Verify the request body
        print("call args:", mock_post.call_args)
        request_body = mock_post.call_args.kwargs["json"]
        print("Request body:", json.dumps(request_body, indent=4, default=str))

        # Assert content from the knowledge base was applied to the request
        # 1. we should have 2 content blocks: the first is the user message,
        #    the second is the context retrieved from the knowledge base
        content = request_body["messages"][0]["content"]
        assert len(content) == 2
        assert content[0]["type"] == "text"
        assert content[1]["type"] == "text"

        # 2. the context message should contain the Bedrock knowledge base prefix
        #    string, confirming the retrieved context was injected into the request
        assert BedrockVectorStore.CONTENT_PREFIX_STRING in content[1]["text"]
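

# NOTE: unlike the mocked test above, the next test exercises the full path with
# a real Anthropic completion plus a real Bedrock knowledge base retrieval, so it
# needs valid credentials for both services.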
@pytest.mark.asyncio
async def test_e2e_bedrock_knowledgebase_retrieval_with_llm_api_call(setup_vector_store_registry):
    """
    Test that the Bedrock Knowledge Base hook works when making a real LLM API call
    """
    # Init client
    litellm._turn_on_debug()
    async_client = AsyncHTTPHandler()
    litellm.callbacks = [BedrockVectorStore(aws_region_name="us-west-2")]
    response = await litellm.acompletion(
        model="anthropic/claude-3-5-haiku-latest",
        messages=[{"role": "user", "content": "what is litellm?"}],
        vector_store_ids=["T37J8R4WTM"],
        client=async_client,
    )
    assert response is not None


@pytest.mark.asyncio
async def test_openai_with_knowledge_base_mock_openai(setup_vector_store_registry):
    """
    Tests that knowledge base content is correctly passed to the OpenAI API call
    """
    litellm.callbacks = [BedrockVectorStore(aws_region_name="us-west-2")]
    litellm.set_verbose = True
    from openai import AsyncOpenAI

    client = AsyncOpenAI(api_key="fake-api-key")
    with patch.object(
        client.chat.completions.with_raw_response, "create"
    ) as mock_client:
        try:
            await litellm.acompletion(
                model="gpt-4",
                messages=[{"role": "user", "content": "what is litellm?"}],
                vector_store_ids=["T37J8R4WTM"],
                client=client,
            )
        except Exception as e:
            print(f"Error: {e}")

        # Verify the API was called
        mock_client.assert_called_once()
        request_body = mock_client.call_args.kwargs

        # Verify the request contains messages with knowledge base context.
        # We expect at least 2 messages:
        # 1. User message with the question
        # 2. User message with the knowledge base context
        assert "messages" in request_body
        messages = request_body["messages"]
        assert len(messages) >= 2
        print("request messages:", json.dumps(messages, indent=4, default=str))

        # messages[1] should be the user message carrying the knowledge base context
        assert messages[1]["role"] == "user"
        assert BedrockVectorStore.CONTENT_PREFIX_STRING in messages[1]["content"]


@pytest.mark.asyncio
async def test_openai_with_vector_store_ids_in_tool_call_mock_openai(setup_vector_store_registry):
    """
    Tests that vector store ids can be passed as tools, using the OpenAI
    `file_search` tool format.
    """
    litellm.callbacks = [BedrockVectorStore(aws_region_name="us-west-2")]
    litellm.set_verbose = True
    from openai import AsyncOpenAI

    client = AsyncOpenAI(api_key="fake-api-key")
    with patch.object(
        client.chat.completions.with_raw_response, "create"
    ) as mock_client:
        try:
            await litellm.acompletion(
                model="gpt-4",
                messages=[{"role": "user", "content": "what is litellm?"}],
                tools=[{
                    "type": "file_search",
                    "vector_store_ids": ["T37J8R4WTM"],
                }],
                client=client,
            )
        except Exception as e:
            print(f"Error: {e}")

        # Verify the API was called
        mock_client.assert_called_once()
        request_body = mock_client.call_args.kwargs
        print("request body:", json.dumps(request_body, indent=4, default=str))

        # Verify the request contains messages with knowledge base context.
        # We expect at least 2 messages:
        # 1. User message with the question
        # 2. User message with the knowledge base context
        assert "messages" in request_body
        messages = request_body["messages"]
        assert len(messages) >= 2
        print("request messages:", json.dumps(messages, indent=4, default=str))

        # messages[1] should be the user message carrying the knowledge base context
        assert messages[1]["role"] == "user"
        assert BedrockVectorStore.CONTENT_PREFIX_STRING in messages[1]["content"]

        # the tool call should not be forwarded to the upstream LLM API when it
        # targets a litellm-managed vector store
        assert "tools" not in request_body


@pytest.mark.asyncio
async def test_openai_with_mixed_tool_call_mock_openai(setup_vector_store_registry):
    """Ensure unrecognized vector store tools are forwarded to the provider."""
    litellm.callbacks = [BedrockVectorStore(aws_region_name="us-west-2")]
    from openai import AsyncOpenAI

    client = AsyncOpenAI(api_key="fake-api-key")
    with patch.object(
        client.chat.completions.with_raw_response, "create"
    ) as mock_client:
        try:
            await litellm.acompletion(
                model="gpt-4",
                messages=[{"role": "user", "content": "what is litellm?"}],
                tools=[
                    {"type": "file_search", "vector_store_ids": ["T37J8R4WTM"]},
                    {"type": "file_search", "vector_store_ids": ["unknownVS"]},
                ],
                client=client,
            )
        except Exception as e:
            print(f"Error: {e}")

        mock_client.assert_called_once()
        request_body = mock_client.call_args.kwargs

        # the registered vector store is resolved into knowledge base context...
        assert "messages" in request_body
        messages = request_body["messages"]
        assert len(messages) >= 2
        assert messages[1]["role"] == "user"
        assert BedrockVectorStore.CONTENT_PREFIX_STRING in messages[1]["content"]

        # ...while the unrecognized vector store tool is passed through unchanged
        assert "tools" in request_body
        tools = request_body["tools"]
        assert len(tools) == 1
        assert tools[0]["vector_store_ids"] == ["unknownVS"]


@pytest.mark.asyncio
async def test_logging_with_knowledge_base_hook(setup_vector_store_registry):
    """
    Test that the knowledge base request was logged in the standard logging payload
    """
    test_custom_logger = TestCustomLogger()
    litellm.callbacks = [BedrockVectorStore(aws_region_name="us-west-2"), test_custom_logger]
    litellm.set_verbose = True
    await litellm.acompletion(
        model="gpt-4",
        messages=[{"role": "user", "content": "what is litellm?"}],
        vector_store_ids=["T37J8R4WTM"],
    )

    # sleep for 1 second to allow the logging callback to run
    await asyncio.sleep(1)

    # assert that the knowledge base request was logged in the standard logging payload
    standard_logging_payload: Optional[StandardLoggingPayload] = test_custom_logger.standard_logging_payload
    assert standard_logging_payload is not None
    metadata = standard_logging_payload["metadata"]
    standard_logging_vector_store_request_metadata: Optional[List[StandardLoggingVectorStoreRequest]] = metadata["vector_store_request_metadata"]
    print("standard_logging_vector_store_request_metadata:", json.dumps(standard_logging_vector_store_request_metadata, indent=4, default=str))

    # 1 vector store request was made, so expect 1 vector store request metadata object
    assert len(standard_logging_vector_store_request_metadata) == 1

    # expect the vector store request metadata object to have the correct values
    vector_store_request_metadata = standard_logging_vector_store_request_metadata[0]
    assert vector_store_request_metadata.get("vector_store_id") == "T37J8R4WTM"
    assert vector_store_request_metadata.get("query") == "what is litellm?"
    assert vector_store_request_metadata.get("custom_llm_provider") == "bedrock"

    # the raw search response should be logged; every returned item (the result
    # set may legitimately be empty) should carry a score and non-empty text content
    vector_store_search_response: VectorStoreSearchResponse = vector_store_request_metadata.get("vector_store_search_response")
    assert vector_store_search_response is not None
    assert vector_store_search_response.get("search_query") == "what is litellm?"
    for item in vector_store_search_response.get("data", []):
        assert item.get("score") is not None
        assert item.get("content") is not None
        for content_item in item.get("content", []):
            text_content = content_item.get("text")
            assert text_content is not None
            assert len(text_content) > 0


@pytest.mark.asyncio
async def test_logging_with_knowledge_base_hook_no_vector_store_registry(setup_vector_store_registry):
    """
    Test that a completion still succeeds when no vector store registry is initialized
    """
    test_custom_logger = TestCustomLogger()
    litellm.callbacks = [BedrockVectorStore(aws_region_name="us-west-2"), test_custom_logger]
    litellm.vector_store_registry = None
    await litellm.acompletion(
        model="gpt-4",
        messages=[{"role": "user", "content": "what is litellm?"}],
    )


@pytest.mark.asyncio
async def test_e2e_bedrock_knowledgebase_retrieval_without_vector_store_registry(setup_vector_store_registry):
    litellm._turn_on_debug()
    client = AsyncHTTPHandler()
    litellm.vector_store_registry = None
    with patch.object(client, "post") as mock_post:
        # Mock the response for the LLM call
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.json = lambda: json.loads(mock_response.text)
        mock_post.return_value = mock_response

        try:
            await litellm.acompletion(
                model="anthropic/claude-3.5-sonnet",
                messages=[{"role": "user", "content": "what is litellm?"}],
                vector_store_ids=["T37J8R4WTM"],
                client=client,
            )
        except Exception as e:
            print(f"Error: {e}")

        # Verify the LLM request was made
        mock_post.assert_called_once()

        # Verify the request body
        print("call args:", mock_post.call_args)
        request_body = mock_post.call_args.kwargs["json"]
        print("Request body:", json.dumps(request_body, indent=4, default=str))

        # Assert no knowledge base content was applied to the request:
        # with no initialized vector store registry there should be exactly
        # one content block, the original user message
        content = request_body["messages"][0]["content"]
        assert len(content) == 1
        assert content[0]["type"] == "text"


@pytest.mark.asyncio
async def test_e2e_bedrock_knowledgebase_retrieval_with_vector_store_not_in_registry(setup_vector_store_registry):
    """
    No vector store request is made for vector store ids that are not in the registry.
    In this test newUnknownVectorStoreId is not in the registry, so no vector store request is made.
    """
    litellm._turn_on_debug()
    client = AsyncHTTPHandler()
    print("Registry initialized:", litellm.vector_store_registry.vector_stores)
    with patch.object(client, "post") as mock_post:
        # Mock the response for the LLM call
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.json = lambda: json.loads(mock_response.text)
        mock_post.return_value = mock_response

        try:
            await litellm.acompletion(
                model="anthropic/claude-3.5-sonnet",
                messages=[{"role": "user", "content": "what is litellm?"}],
                vector_store_ids=["newUnknownVectorStoreId"],
                client=client,
            )
        except Exception as e:
            print(f"Error: {e}")

        # Verify the LLM request was made
        mock_post.assert_called_once()

        # Verify the request body
        print("call args:", mock_post.call_args)
        request_body = mock_post.call_args.kwargs["json"]
        print("Request body:", json.dumps(request_body, indent=4, default=str))

        # Assert no knowledge base content was applied to the request:
        # the requested id is not in the registry, so there should be exactly
        # one content block, the original user message
        content = request_body["messages"][0]["content"]
        assert len(content) == 1
        assert content[0]["type"] == "text"