"""Wrapper functions around an LLM chain.""" import logging from typing import Any, List, Optional, Union import openai from langchain import LLMChain from langchain.chat_models import ChatOpenAI from langchain.prompts.base import BasePromptTemplate from langchain.prompts.chat import ( BaseMessagePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, ) from langchain.prompts.prompt import PromptTemplate from langchain.schema import BaseLanguageModel, BaseMessage from gpt_index.llm_predictor.base import LLMPredictor from gpt_index.prompts.base import Prompt from gpt_index.utils import ErrorToRetry, retry_on_exceptions_with_backoff logger = logging.getLogger(__name__) class ChatGPTLLMPredictor(LLMPredictor): """ChatGPT Specific LLM predictor class. Wrapper around an LLMPredictor to provide ChatGPT specific features. Args: llm (Optional[langchain.llms.base.LLM]): LLM from Langchain to use for predictions. Defaults to OpenAI's text-davinci-003 model. Please see `Langchain's LLM Page `_ for more details. retry_on_throttling (bool): Whether to retry on rate limit errors. Defaults to true. """ def __init__( self, llm: Optional[BaseLanguageModel] = None, prepend_messages: Optional[ List[Union[BaseMessagePromptTemplate, BaseMessage]] ] = None, **kwargs: Any ) -> None: """Initialize params.""" super().__init__( llm=llm or ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo"), **kwargs ) self.prepend_messages = prepend_messages def _get_langchain_prompt( self, prompt: Prompt ) -> Union[ChatPromptTemplate, BasePromptTemplate]: """Add prepend_messages to prompt.""" lc_prompt = prompt.get_langchain_prompt(llm=self._llm) if self.prepend_messages: if isinstance(lc_prompt, PromptTemplate): msgs = self.prepend_messages + [ HumanMessagePromptTemplate.from_template(lc_prompt.template) ] lc_prompt = ChatPromptTemplate.from_messages(msgs) elif isinstance(lc_prompt, ChatPromptTemplate): lc_prompt.messages = self.prepend_messages + lc_prompt.messages return lc_prompt def _predict(self, prompt: Prompt, **prompt_args: Any) -> str: """Inner predict function. If retry_on_throttling is true, we will retry on rate limit errors. """ lc_prompt = self._get_langchain_prompt(prompt) llm_chain = LLMChain(prompt=lc_prompt, llm=self._llm) # Note: we don't pass formatted_prompt to llm_chain.predict because # langchain does the same formatting under the hood full_prompt_args = prompt.get_full_format_args(prompt_args) if self.retry_on_throttling: llm_prediction = retry_on_exceptions_with_backoff( lambda: llm_chain.predict(**full_prompt_args), [ ErrorToRetry(openai.error.RateLimitError), ErrorToRetry(openai.error.ServiceUnavailableError), ErrorToRetry(openai.error.TryAgain), ErrorToRetry( openai.error.APIConnectionError, lambda e: e.should_retry ), ], ) else: llm_prediction = llm_chain.predict(**full_prompt_args) return llm_prediction async def _apredict(self, prompt: Prompt, **prompt_args: Any) -> str: """Async inner predict function. If retry_on_throttling is true, we will retry on rate limit errors. """ lc_prompt = self._get_langchain_prompt(prompt) llm_chain = LLMChain(prompt=lc_prompt, llm=self._llm) # Note: we don't pass formatted_prompt to llm_chain.predict because # langchain does the same formatting under the hood full_prompt_args = prompt.get_full_format_args(prompt_args) # TODO: support retry on throttling llm_prediction = await llm_chain.apredict(**full_prompt_args) return llm_prediction