# libraries
from __future__ import annotations

import json
import logging
from typing import Any, Callable, Dict, Optional

from colorama import Fore, Style
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate

from ..prompts import generate_subtopics_prompt
from .costs import estimate_llm_cost
from .validators import Subtopics


def get_llm(llm_provider, **kwargs):
    """Return a GenericLLMProvider instance configured for the given provider name."""
    from gpt_researcher.llm_provider import GenericLLMProvider
    return GenericLLMProvider.from_provider(llm_provider, **kwargs)
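

# Illustrative sketch (not part of the module above): get_llm forwards its keyword
# arguments straight to GenericLLMProvider.from_provider, so any option the underlying
# provider accepts can be passed through. The provider name "openai" and the model
# "gpt-4o-mini" below are assumptions for illustration only.
def _example_get_llm():
    provider = get_llm("openai", model="gpt-4o-mini", temperature=0.4, max_tokens=4000)
    return provider.llm  # the underlying LangChain chat model instance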


async def create_chat_completion(
    messages: list,  # type: ignore
    model: Optional[str] = None,
    temperature: Optional[float] = 0.4,
    max_tokens: Optional[int] = 4000,
    llm_provider: Optional[str] = None,
    stream: Optional[bool] = False,
    websocket: Any | None = None,
    llm_kwargs: Dict[str, Any] | None = None,
    cost_callback: Optional[Callable] = None,
) -> str:
"""Create a chat completion using the OpenAI API | |
Args: | |
messages (list[dict[str, str]]): The messages to send to the chat completion | |
model (str, optional): The model to use. Defaults to None. | |
temperature (float, optional): The temperature to use. Defaults to 0.4. | |
max_tokens (int, optional): The max tokens to use. Defaults to 4000. | |
stream (bool, optional): Whether to stream the response. Defaults to False. | |
llm_provider (str, optional): The LLM Provider to use. | |
webocket (WebSocket): The websocket used in the currect request, | |
cost_callback: Callback function for updating cost | |
Returns: | |
str: The response from the chat completion | |
""" | |
    # validate input
    if model is None:
        raise ValueError("Model cannot be None")
    if max_tokens is not None and max_tokens > 16000:
        raise ValueError(
            f"Max tokens cannot be more than 16,000, but got {max_tokens}")

    # Get the provider from supported providers
    provider = get_llm(llm_provider, model=model, temperature=temperature,
                       max_tokens=max_tokens, **(llm_kwargs or {}))
response = "" | |
# create response | |
for _ in range(10): # maximum of 10 attempts | |
response = await provider.get_chat_response( | |
messages, stream, websocket | |
) | |
if cost_callback: | |
llm_costs = estimate_llm_cost(str(messages), response) | |
cost_callback(llm_costs) | |
return response | |
logging.error(f"Failed to get response from {llm_provider} API") | |
raise RuntimeError(f"Failed to get response from {llm_provider} API") | |
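

# Illustrative usage sketch (not part of the module above). The provider name "openai"
# and the model "gpt-4o-mini" are assumptions for illustration; substitute whichever
# provider/model your configuration supports. Run it with asyncio.run(_example_chat_completion()).
async def _example_chat_completion() -> str:
    messages = [
        {"role": "system", "content": "You are a concise research assistant."},
        {"role": "user", "content": "List three applications of graph databases."},
    ]
    return await create_chat_completion(
        messages=messages,
        model="gpt-4o-mini",    # assumed model name
        temperature=0.4,
        max_tokens=1000,
        llm_provider="openai",  # assumed provider name
        cost_callback=lambda cost: print("Estimated LLM cost:", cost),
    )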


async def construct_subtopics(task: str, data: str, config, subtopics: list = []) -> list:
    """
    Construct subtopics based on the given task and data.

    Args:
        task (str): The main task or topic.
        data (str): Additional data for context.
        config: Configuration settings.
        subtopics (list, optional): Existing subtopics. Defaults to [].

    Returns:
        list: A list of constructed subtopics.
    """
    try:
        parser = PydanticOutputParser(pydantic_object=Subtopics)

        prompt = PromptTemplate(
            template=generate_subtopics_prompt(),
            input_variables=["task", "data", "subtopics", "max_subtopics"],
            partial_variables={
                "format_instructions": parser.get_format_instructions()},
        )

        print(f"\n🤖 Calling {config.smart_llm_model}...\n")

        temperature = config.temperature
        # temperature = 0  # Note: temperature throughout the code base is currently set to Zero

        provider = get_llm(
            config.smart_llm_provider,
            model=config.smart_llm_model,
            temperature=temperature,
            max_tokens=config.smart_token_limit,
            **config.llm_kwargs,
        )
        model = provider.llm

        chain = prompt | model | parser

        output = chain.invoke({
            "task": task,
            "data": data,
            "subtopics": subtopics,
            "max_subtopics": config.max_subtopics,
        })

        return output

    except Exception as e:
        print("Exception in parsing subtopics:", e)
        return subtopics
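

# Illustrative usage sketch (not part of the module above). construct_subtopics expects
# a config object exposing smart_llm_provider, smart_llm_model, temperature,
# smart_token_limit, llm_kwargs and max_subtopics; in gpt-researcher this is normally
# the library's Config object (the import path below is an assumption).
async def _example_construct_subtopics() -> list:
    from gpt_researcher.config import Config  # assumed import path
    config = Config()
    return await construct_subtopics(
        task="The impact of quantum computing on modern cryptography",
        data="...previously gathered research context...",
        config=config,
        subtopics=[],
    )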