"""Wrapper functions around an LLM chain."""

import logging
from typing import Any, List, Optional, Union

import openai
from langchain import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
    BaseMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage

from gpt_index.llm_predictor.base import LLMPredictor
from gpt_index.prompts.base import Prompt
from gpt_index.utils import ErrorToRetry, retry_on_exceptions_with_backoff

logger = logging.getLogger(__name__)


class ChatGPTLLMPredictor(LLMPredictor):
    """ChatGPT Specific LLM predictor class.

    Wrapper around an LLMPredictor to provide ChatGPT specific features.

    Args:
        llm (Optional[langchain.llms.base.LLM]): LLM from Langchain to use
            for predictions. Defaults to OpenAI's text-davinci-003 model.
            Please see `Langchain's LLM Page
            <https://langchain.readthedocs.io/en/latest/modules/llms.html>`_
            for more details.

        retry_on_throttling (bool): Whether to retry on rate limit errors.
            Defaults to true.

    """

    def __init__(
        self,
        llm: Optional[BaseLanguageModel] = None,
        prepend_messages: Optional[
            List[Union[BaseMessagePromptTemplate, BaseMessage]]
        ] = None,
        **kwargs: Any
    ) -> None:
        """Initialize params."""
        super().__init__(
            llm=llm or ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo"), **kwargs
        )
        self.prepend_messages = prepend_messages

    def _get_langchain_prompt(
        self, prompt: Prompt
    ) -> Union[ChatPromptTemplate, BasePromptTemplate]:
        """Add prepend_messages to prompt."""
        lc_prompt = prompt.get_langchain_prompt(llm=self._llm)
        if self.prepend_messages:
            if isinstance(lc_prompt, PromptTemplate):
                msgs = self.prepend_messages + [
                    HumanMessagePromptTemplate.from_template(lc_prompt.template)
                ]
                lc_prompt = ChatPromptTemplate.from_messages(msgs)
            elif isinstance(lc_prompt, ChatPromptTemplate):
                lc_prompt.messages = self.prepend_messages + lc_prompt.messages

        return lc_prompt

    def _predict(self, prompt: Prompt, **prompt_args: Any) -> str:
        """Inner predict function.

        If ``retry_on_throttling`` is True, retry on rate limit and other
        transient OpenAI errors with backoff.

        """
        lc_prompt = self._get_langchain_prompt(prompt)
        llm_chain = LLMChain(prompt=lc_prompt, llm=self._llm)

        # Note: we don't pass formatted_prompt to llm_chain.predict because
        # langchain does the same formatting under the hood
        full_prompt_args = prompt.get_full_format_args(prompt_args)
        if self.retry_on_throttling:
            llm_prediction = retry_on_exceptions_with_backoff(
                lambda: llm_chain.predict(**full_prompt_args),
                [
                    ErrorToRetry(openai.error.RateLimitError),
                    ErrorToRetry(openai.error.ServiceUnavailableError),
                    ErrorToRetry(openai.error.TryAgain),
                    ErrorToRetry(
                        openai.error.APIConnectionError, lambda e: e.should_retry
                    ),
                ],
            )
        else:
            llm_prediction = llm_chain.predict(**full_prompt_args)
        return llm_prediction

    async def _apredict(self, prompt: Prompt, **prompt_args: Any) -> str:
        """Async inner predict function.

        If retry_on_throttling is true, we will retry on rate limit errors.

        """
        lc_prompt = self._get_langchain_prompt(prompt)
        llm_chain = LLMChain(prompt=lc_prompt, llm=self._llm)

        # Note: we don't pass formatted_prompt to llm_chain.predict because
        # langchain does the same formatting under the hood
        full_prompt_args = prompt.get_full_format_args(prompt_args)
        # TODO: support retry on throttling
        llm_prediction = await llm_chain.apredict(**full_prompt_args)
        return llm_prediction
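

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). It assumes a
# valid OPENAI_API_KEY is configured, and it calls the private ``_predict``
# method directly for brevity; in normal use the predictor is passed to an
# index or invoked via the public ``predict`` on the LLMPredictor base class.
# The question-answer template string below is a made-up example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain.schema import SystemMessage

    from gpt_index.prompts.prompts import QuestionAnswerPrompt

    # Prepend a system message to every prompt sent to gpt-3.5-turbo.
    predictor = ChatGPTLLMPredictor(
        prepend_messages=[SystemMessage(content="You are a concise assistant.")],
    )
    qa_prompt = QuestionAnswerPrompt(
        "Context:\n{context_str}\n\nAnswer the question: {query_str}\n"
    )
    answer = predictor._predict(
        qa_prompt,
        context_str="ChatGPTLLMPredictor wraps a Langchain LLMChain.",
        query_str="What does ChatGPTLLMPredictor wrap?",
    )
    print(answer)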