|
import time |
|
import openai |
|
from typing import Union, List, Dict |
|
|
|
|
|
def get_gpt_output(message: Union[str, List[Dict[str, str]]],
                   model: str = "gpt-4-1106-preview",
                   max_tokens: int = 2048,
                   temperature: float = 1.0,
                   max_retry: int = 1,
                   sleep_time: int = 60,
                   json_object: bool = False) -> str:
    """
    Call the OpenAI API to get the GPT model output for a given prompt.

    Args:
        message (Union[str, List[Dict[str, str]]]): The input message or a list of message dicts.
        model (str): The model to use for completion. Default is "gpt-4-1106-preview".
        max_tokens (int): Maximum number of tokens to generate. Default is 2048.
        temperature (float): Sampling temperature. Default is 1.0.
        max_retry (int): Maximum number of attempts in case of an error. Default is 1.
        sleep_time (int): Sleep time between retries in seconds. Default is 60.
        json_object (bool): Whether to request JSON-formatted output. Default is False.

    Returns:
        str: The completed text generated by the GPT model.

    Raises:
        Exception: If the completion fails after the maximum number of retries.
            The last underlying API error is attached as the exception's cause.
    """
    # The JSON response format requires the word "json" to appear somewhere in
    # the prompt; prepend a system-style hint if a plain string lacks it.
    if json_object and isinstance(message, str) and 'json' not in message.lower():
        message = 'You are a helpful assistant designed to output JSON. ' + message

    # Normalize a bare string into the chat-messages structure the API expects.
    if isinstance(message, str):
        messages = [{"role": "user", "content": message}]
    else:
        messages = message

    kwargs = {"response_format": {"type": "json_object"}} if json_object else {}

    last_error = None
    for attempt in range(1, max_retry + 1):  # 1-based for readable log output
        try:
            # Client construction stays inside the try so a configuration error
            # (e.g. missing API key) flows through the same retry/raise path.
            chat = openai.OpenAI().chat.completions.create(
                messages=messages,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )
            return chat.choices[0].message.content
        except Exception as e:
            last_error = e
            if attempt < max_retry:
                print(f"Attempt {attempt} failed: {e}. Retrying after {sleep_time} seconds...")
                time.sleep(sleep_time)  # only sleep when another attempt follows
            else:
                # Bug fix: the original slept for sleep_time seconds even after
                # the final attempt, delaying the raise for no benefit.
                print(f"Attempt {attempt} failed: {e}. No retries left.")

    # Chain the last API error so the real cause survives in the traceback.
    raise Exception("Failed to get GPT output after maximum retries") from last_error
|
|