Spaces:
Runtime error
Runtime error
| from typing import Optional, Union | |
| import weave | |
| from openai import OpenAI | |
| from openai.types.chat import ChatCompletion | |
class OpenAIModel(weave.Model):
    """
    A class to interface with OpenAI's language models using the Weave framework.

    This class provides methods to build chat messages and generate predictions
    using OpenAI's chat completion API. It accepts either a single user prompt or
    a list of user prompts, and can optionally prepend a system prompt to guide
    the model's responses.

    Args:
        model_name (str): The name of the OpenAI model to be used for
            predictions. Defaults to ``"gpt-4o"``.
    """

    model_name: str
    # Underscore-prefixed, so presumably treated as a private (non-field)
    # attribute by the pydantic-based ``weave.Model`` base — confirm.
    # It is assigned in ``__init__`` after the base initializer runs.
    _openai_client: OpenAI

    def __init__(self, model_name: str = "gpt-4o") -> None:
        super().__init__(model_name=model_name)
        # ``OpenAI()`` is built with no explicit arguments, so its credentials
        # and base URL come from the SDK's defaults (e.g. environment variables).
        self._openai_client = OpenAI()

    def create_messages(
        self,
        user_prompts: Union[str, list[str]],
        system_prompt: Optional[str] = None,
        messages: Optional[list[dict]] = None,
    ) -> list[dict]:
        """
        Create a list of messages for the OpenAI chat completion API.

        The result is built in three steps:

        1. Start from a copy of ``messages`` (the existing conversation).
        2. Append one ``{"role": "user", "content": ...}`` entry per user prompt.
        3. If ``system_prompt`` is given, put a
           ``{"role": "system", "content": ...}`` entry at the very front,
           before any existing messages.

        Args:
            user_prompts (Union[str, list[str]]): A single user prompt or a list
                of user prompts to be appended as user messages.
            system_prompt (Optional[str]): An optional system prompt to guide the
                model's responses. If provided, it is inserted at the beginning
                of the returned list.
            messages (Optional[list[dict]]): An optional list of existing
                messages that the new prompts are appended after. A single
                message dict is also accepted and treated as a one-element list.
                The caller's object is never mutated. If ``None``, the result
                starts empty.

        Returns:
            list[dict]: A new list of messages formatted for the OpenAI chat
            completion API.
        """
        if isinstance(user_prompts, str):
            user_prompts = [user_prompts]

        # Copy the caller's history so that appending prompts never mutates it.
        if messages is None:
            history: list[dict] = []
        elif isinstance(messages, dict):
            # A lone message dict: wrap it instead of iterating over its keys.
            history = [messages]
        else:
            history = list(messages)

        history.extend(
            {"role": "user", "content": user_prompt} for user_prompt in user_prompts
        )

        if system_prompt is not None:
            history.insert(0, {"role": "system", "content": system_prompt})
        return history

    def predict(
        self,
        user_prompts: Union[str, list[str]],
        system_prompt: Optional[str] = None,
        messages: Optional[list[dict]] = None,
        **kwargs,
    ) -> ChatCompletion:
        """
        Generate a chat completion response using the OpenAI API.

        The message list is built with ``create_messages`` and then sent to
        OpenAI using the configured ``model_name``:

        - If ``response_format`` is among ``kwargs``, the request goes through
          ``client.beta.chat.completions.parse`` (the SDK's structured-output
          helper).
        - Otherwise it goes through ``client.chat.completions.create``.

        Args:
            user_prompts (Union[str, list[str]]): A single user prompt or a list
                of user prompts to be included in the messages.
            system_prompt (Optional[str]): An optional system prompt to guide the
                model's responses. If provided, it is placed at the beginning of
                the messages list.
            messages (Optional[list[dict]]): An optional list of existing
                messages that the new prompts are appended after.
            **kwargs: Additional keyword arguments forwarded unchanged to the
                OpenAI chat completion call. They must not include ``model`` or
                ``messages``, which this method supplies.

        Returns:
            ChatCompletion: The chat completion response from the OpenAI API.
            On the ``parse`` path this is whatever that helper returns,
            presumably a ``ChatCompletion`` subclass — confirm against the SDK.
        """
        messages = self.create_messages(user_prompts, system_prompt, messages)
        if "response_format" in kwargs:
            response = self._openai_client.beta.chat.completions.parse(
                model=self.model_name, messages=messages, **kwargs
            )
        else:
            response = self._openai_client.chat.completions.create(
                model=self.model_name, messages=messages, **kwargs
            )
        return response