"""Wrapper functions around an LLM chain."""
import logging
from typing import Any, List, Optional, Union
import openai
from langchain import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
BaseMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage
from gpt_index.llm_predictor.base import LLMPredictor
from gpt_index.prompts.base import Prompt
from gpt_index.utils import ErrorToRetry, retry_on_exceptions_with_backoff
logger = logging.getLogger(__name__)


class ChatGPTLLMPredictor(LLMPredictor):
    """ChatGPT-specific LLM predictor class.

    Wrapper around an LLMPredictor to provide ChatGPT-specific features.

    Args:
        llm (Optional[langchain.schema.BaseLanguageModel]): Language model from
            Langchain to use for predictions. Defaults to OpenAI's
            gpt-3.5-turbo model via ChatOpenAI. Please see `Langchain's LLM Page
            <https://langchain.readthedocs.io/en/latest/modules/llms.html>`_
            for more details.
        prepend_messages (Optional[List[Union[BaseMessagePromptTemplate,
            BaseMessage]]]): Messages or message templates to prepend to
            every prompt. Defaults to None.
        retry_on_throttling (bool): Whether to retry on rate limit errors.
            Defaults to True.

    """

    def __init__(
        self,
        llm: Optional[BaseLanguageModel] = None,
        prepend_messages: Optional[
            List[Union[BaseMessagePromptTemplate, BaseMessage]]
        ] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        super().__init__(
            llm=llm or ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo"), **kwargs
        )
        self.prepend_messages = prepend_messages

    def _get_langchain_prompt(
        self, prompt: Prompt
    ) -> Union[ChatPromptTemplate, BasePromptTemplate]:
        """Add prepend_messages to prompt."""
        lc_prompt = prompt.get_langchain_prompt(llm=self._llm)
        if self.prepend_messages:
            if isinstance(lc_prompt, PromptTemplate):
                msgs = self.prepend_messages + [
                    HumanMessagePromptTemplate.from_template(lc_prompt.template)
                ]
                lc_prompt = ChatPromptTemplate.from_messages(msgs)
            elif isinstance(lc_prompt, ChatPromptTemplate):
                lc_prompt.messages = self.prepend_messages + lc_prompt.messages
        return lc_prompt

    def _predict(self, prompt: Prompt, **prompt_args: Any) -> str:
        """Inner predict function.

        If retry_on_throttling is true, we will retry on rate limit errors.

        """
        lc_prompt = self._get_langchain_prompt(prompt)
        llm_chain = LLMChain(prompt=lc_prompt, llm=self._llm)
        # Note: we don't pass formatted_prompt to llm_chain.predict because
        # langchain does the same formatting under the hood
        full_prompt_args = prompt.get_full_format_args(prompt_args)
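        # retry_on_exceptions_with_backoff (from gpt_index.utils) is assumed,
        # per its name and usage here, to invoke the zero-arg callable and
        # retry with exponential backoff when one of the listed exception
        # types is raised; the optional second argument to ErrorToRetry is a
        # predicate that must return True for the error to be retried.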
        if self.retry_on_throttling:
            llm_prediction = retry_on_exceptions_with_backoff(
                lambda: llm_chain.predict(**full_prompt_args),
                [
                    ErrorToRetry(openai.error.RateLimitError),
                    ErrorToRetry(openai.error.ServiceUnavailableError),
                    ErrorToRetry(openai.error.TryAgain),
                    ErrorToRetry(
                        openai.error.APIConnectionError, lambda e: e.should_retry
                    ),
                ],
            )
        else:
            llm_prediction = llm_chain.predict(**full_prompt_args)
        return llm_prediction

    async def _apredict(self, prompt: Prompt, **prompt_args: Any) -> str:
        """Async inner predict function.

        If retry_on_throttling is true, we will retry on rate limit errors.

        """
        lc_prompt = self._get_langchain_prompt(prompt)
        llm_chain = LLMChain(prompt=lc_prompt, llm=self._llm)
        # Note: we don't pass formatted_prompt to llm_chain.apredict because
        # langchain does the same formatting under the hood
        full_prompt_args = prompt.get_full_format_args(prompt_args)
        # TODO: support retry on throttling
        llm_prediction = await llm_chain.apredict(**full_prompt_args)
        return llm_prediction
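

if __name__ == "__main__":
    # Minimal usage sketch wiring the predictor into a gpt_index index.
    # Assumptions: the GPTSimpleVectorIndex / SimpleDirectoryReader API of this
    # gpt_index version, an OPENAI_API_KEY in the environment, and a local
    # "data" directory; the system message text and query are illustrative.
    from langchain.schema import SystemMessage

    from gpt_index import GPTSimpleVectorIndex, SimpleDirectoryReader

    predictor = ChatGPTLLMPredictor(
        prepend_messages=[SystemMessage(content="Answer as concisely as possible.")],
        retry_on_throttling=True,
    )
    documents = SimpleDirectoryReader("data").load_data()
    index = GPTSimpleVectorIndex(documents, llm_predictor=predictor)
    print(index.query("Summarize the documents."))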