Spaces:
Running
on
Zero
Running
on
Zero
| from openai import OpenAI | |
| from PIL import Image | |
| import requests | |
| import io | |
| import os | |
| import base64 | |
class OpenaiModel:
    """Callable wrapper that generates images with OpenAI's DALL-E models.

    Only the ``"text2image"`` model type is implemented. The concrete API model
    is selected by substring match on ``model_name`` (``"Dalle-3"`` or
    ``"Dalle-2"``).

    Example::

        pipe = OpenaiModel("Dalle-2", "text2image")
        image = pipe(prompt="draw a tiger")  # PIL image
    """

    # (substring of model_name, OpenAI API model id, requested image size).
    # Checked in order, so a name containing both tags resolves to DALL-E 3,
    # preserving the precedence of the original if/elif chain.
    _TEXT2IMAGE_MODELS = (
        ("Dalle-3", "dall-e-3", "1024x1024"),
        ("Dalle-2", "dall-e-2", "512x512"),
    )

    # Seconds to wait while downloading the generated image; without a timeout
    # requests.get can block indefinitely on a stalled connection.
    _DOWNLOAD_TIMEOUT = 60

    def __init__(self, model_name, model_type):
        """Store the model configuration; no network access happens here.

        Args:
            model_name: Name containing "Dalle-3" or "Dalle-2"; selects the
                OpenAI model used by ``__call__``.
            model_type: Task type; only "text2image" is supported by
                ``__call__``.
        """
        self.model_name = model_name
        self.model_type = model_type

    def __call__(self, *args, **kwargs):
        """Generate a single image from ``kwargs["prompt"]``.

        Positional arguments are accepted but ignored.

        Returns:
            The generated image as a PIL image, downloaded from the URL the
            Images API returns.

        Raises:
            ValueError: If ``model_type`` is not "text2image".
            AssertionError: If no ``prompt`` keyword argument is given.
            NotImplementedError: If ``model_name`` matches no known DALL-E model.
            requests.HTTPError: If downloading the generated image fails.
        """
        if self.model_type != "text2image":
            raise ValueError(
                f"model_type must be 'text2image', got {self.model_type!r}"
            )
        assert "prompt" in kwargs, "prompt is required for text2image model"
        # Resolve the model before creating the client so an unsupported name
        # fails fast, without constructing (or configuring) an OpenAI client.
        api_model, size = self._resolve_text2image_model()
        client = OpenAI()
        response = client.images.generate(
            model=api_model,
            prompt=kwargs["prompt"],
            size=size,
            quality="standard",
            n=1,
        )
        return self._download_image(response.data[0].url)

    def _resolve_text2image_model(self):
        """Return ``(api_model, size)`` for ``self.model_name``."""
        for tag, api_model, size in self._TEXT2IMAGE_MODELS:
            if tag in self.model_name:
                return api_model, size
        raise NotImplementedError(
            f"unsupported text2image model: {self.model_name!r}"
        )

    @classmethod
    def _download_image(cls, url):
        """Fetch ``url`` and decode the response body as a PIL image."""
        http_response = requests.get(url, timeout=cls._DOWNLOAD_TIMEOUT)
        # Surface HTTP failures directly instead of letting PIL fail to parse
        # an error page.
        http_response.raise_for_status()
        return Image.open(io.BytesIO(http_response.content))
def load_openai_model(model_name, model_type):
    """Construct an ``OpenaiModel`` for the given model name and task type.

    Construction only stores the two arguments; the OpenAI API is contacted
    when the returned object is called.
    """
    model = OpenaiModel(model_name=model_name, model_type=model_type)
    return model
if __name__ == "__main__":
    # Manual smoke test: calls the OpenAI Images API with DALL-E 2, downloads
    # the generated image, and prints the resulting PIL image object.
    demo = load_openai_model(model_name='Dalle-2', model_type='text2image')
    print(demo(prompt='draw a tiger'))