"""Wrapper functions around an LLM chain."""
import logging
from typing import Any, List, Optional, Union

import openai
from langchain import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
    BaseMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage

from gpt_index.llm_predictor.base import LLMPredictor
from gpt_index.prompts.base import Prompt
from gpt_index.utils import ErrorToRetry, retry_on_exceptions_with_backoff

logger = logging.getLogger(__name__)


class ChatGPTLLMPredictor(LLMPredictor):
"""ChatGPT Specific LLM predictor class.
Wrapper around an LLMPredictor to provide ChatGPT specific features.
Args:
llm (Optional[langchain.llms.base.LLM]): LLM from Langchain to use
for predictions. Defaults to OpenAI's text-davinci-003 model.
Please see `Langchain's LLM Page
<https://langchain.readthedocs.io/en/latest/modules/llms.html>`_
for more details.
retry_on_throttling (bool): Whether to retry on rate limit errors.
Defaults to true.
"""
def __init__(
self,
llm: Optional[BaseLanguageModel] = None,
prepend_messages: Optional[
List[Union[BaseMessagePromptTemplate, BaseMessage]]
] = None,
**kwargs: Any
) -> None:
"""Initialize params."""
super().__init__(
llm=llm or ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo"), **kwargs
)
        self.prepend_messages = prepend_messages

    def _get_langchain_prompt(
        self, prompt: Prompt
    ) -> Union[ChatPromptTemplate, BasePromptTemplate]:
"""Add prepend_messages to prompt."""
lc_prompt = prompt.get_langchain_prompt(llm=self._llm)
if self.prepend_messages:
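            # A plain (completion-style) PromptTemplate is converted into a chat prompt:
            # the prepended messages come first, followed by the original template wrapped
            # as a human message. An existing ChatPromptTemplate keeps its own messages
            # and simply has prepend_messages inserted in front of them.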
if isinstance(lc_prompt, PromptTemplate):
msgs = self.prepend_messages + [
HumanMessagePromptTemplate.from_template(lc_prompt.template)
]
lc_prompt = ChatPromptTemplate.from_messages(msgs)
elif isinstance(lc_prompt, ChatPromptTemplate):
lc_prompt.messages = self.prepend_messages + lc_prompt.messages
return lc_prompt

    def _predict(self, prompt: Prompt, **prompt_args: Any) -> str:
        """Inner predict function.

        If retry_on_throttling is True, we will retry on rate limit errors.

        """
lc_prompt = self._get_langchain_prompt(prompt)
llm_chain = LLMChain(prompt=lc_prompt, llm=self._llm)
# Note: we don't pass formatted_prompt to llm_chain.predict because
# langchain does the same formatting under the hood
full_prompt_args = prompt.get_full_format_args(prompt_args)
if self.retry_on_throttling:
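            # Retry with backoff on transient OpenAI errors; APIConnectionError is only
            # retried when the exception itself reports that a retry is safe.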
llm_prediction = retry_on_exceptions_with_backoff(
lambda: llm_chain.predict(**full_prompt_args),
[
ErrorToRetry(openai.error.RateLimitError),
ErrorToRetry(openai.error.ServiceUnavailableError),
ErrorToRetry(openai.error.TryAgain),
ErrorToRetry(
openai.error.APIConnectionError, lambda e: e.should_retry
),
],
)
else:
llm_prediction = llm_chain.predict(**full_prompt_args)
return llm_prediction

    async def _apredict(self, prompt: Prompt, **prompt_args: Any) -> str:
        """Async inner predict function.

        Note: retries on throttling are not yet supported here (see the TODO below).

        """
lc_prompt = self._get_langchain_prompt(prompt)
llm_chain = LLMChain(prompt=lc_prompt, llm=self._llm)
# Note: we don't pass formatted_prompt to llm_chain.predict because
# langchain does the same formatting under the hood
full_prompt_args = prompt.get_full_format_args(prompt_args)
# TODO: support retry on throttling
llm_prediction = await llm_chain.apredict(**full_prompt_args)
return llm_prediction
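

if __name__ == "__main__":
    # Usage sketch (illustrative): construct the predictor with a prepended system
    # message. Assumes a valid OPENAI_API_KEY is set in the environment; the message
    # text and the logging call below are arbitrary examples, not part of any API.
    from langchain.schema import SystemMessage

    predictor = ChatGPTLLMPredictor(
        prepend_messages=[SystemMessage(content="You are a helpful assistant.")]
    )
    # The predictor can then be passed to a gpt_index index via its `llm_predictor`
    # argument, where the installed gpt_index version supports it.
    logger.info("prepend_messages: %s", predictor.prepend_messages)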