Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| import io | |
| from PIL import Image | |
| import requests | |
| import random | |
| import dom | |
| import os | |
# Number of avatar candidates requested for every generation.
NUM_IMAGES = 2

# Device selection: CUDA first, then Apple MPS, otherwise CPU.
# NOTE(review): `device` is not referenced again in the visible code.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Model configuration.
# Remote text-to-image endpoint; the bearer token is read from the
# `api_token` environment variable.
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
headers = {"Authorization": f"Bearer {os.getenv('api_token')}"}

# Image-description model, pinned to a fixed revision.
model_id_image_description = "vikhyatk/moondream2"
revision = "2024-08-26"

# bfloat16 when CUDA is available (GPU optimisation), float32 otherwise.
# NOTE(review): `torch_dtype` is not passed to `from_pretrained` below.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Persistent model loading: done once, at import time.
print("Cargando modelo de descripción de imágenes...")
model_description = AutoModelForCausalLM.from_pretrained(
    model_id_image_description,
    trust_remote_code=True,
    revision=revision,
)
tokenizer_description = AutoTokenizer.from_pretrained(
    model_id_image_description,
    revision=revision,
)
def generate_description(image_path):
    """Describe the image at *image_path* with the image-description model.

    Args:
        image_path: Anything ``Image.open`` accepts (the UI passes a file
            path).

    Returns:
        The value returned by ``model_description.answer_question`` for the
        fixed avatar prompt.
    """
    # The context manager releases the underlying file handle once the image
    # has been encoded, instead of leaving it to the garbage collector.
    with Image.open(image_path) as source_image:
        encoded_image = model_description.encode_image(source_image)
    return model_description.answer_question(
        encoded_image,
        "Describe this image to create an avatar",
        tokenizer_description,
    )
def query(payload, timeout=120):
    """POST *payload* as JSON to ``API_URL`` and return the raw response body.

    Args:
        payload: JSON-serialisable request body.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely.

    Returns:
        The response body as ``bytes``.

    Raises:
        requests.HTTPError: The endpoint answered with a 4xx/5xx status.
        requests.Timeout: No response arrived within *timeout* seconds.
    """
    response = requests.post(
        API_URL,
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    # Fail here on API errors rather than handing an error body to the caller
    # as if it were image data.
    response.raise_for_status()
    return response.content
def generate_image_by_description(description, avatar_style=None):
    """Generate ``NUM_IMAGES`` avatar images from a textual description.

    Args:
        description: Text inserted into the generation prompt.
        avatar_style: Optional style name appended to the prompt when truthy.

    Returns:
        A list of PIL images, one per request sent through ``query``.

    Raises:
        gr.Error: The bytes returned by ``query`` could not be opened as an
            image; the message carries the start of the returned payload.
    """
    # The prompt is the same for every candidate, so build it once; only the
    # seed varies between requests.
    prompt = f"Create a pigeon profile avatar. Use the following description: {description}."
    if avatar_style:
        prompt += f" Use {avatar_style} style."

    images = []
    for _ in range(NUM_IMAGES):
        payload = {"inputs": prompt, "parameters": {"seed": random.randint(0, 1000)}}
        image_bytes = query(payload)
        try:
            image = Image.open(io.BytesIO(image_bytes))
        except OSError as err:
            # PIL's "cannot identify image file" error is an OSError subclass.
            # Show the user what came back instead of an opaque failure.
            detail = image_bytes[:200].decode("utf-8", errors="replace")
            raise gr.Error(f"Image generation failed: {detail}") from err
        images.append(image)
    print(images)
    return images
def process_and_generate(image, avatar_style):
    """Describe the uploaded image, then generate avatars from the description.

    Args:
        image: Value of the image input component (a file path, or ``None``
            when nothing has been uploaded).
        avatar_style: Selected style, or ``None`` when no style was chosen.

    Returns:
        The list of images returned by ``generate_image_by_description``.

    Raises:
        gr.Error: No image was provided.
    """
    # The button can be clicked with an empty image input; report that to the
    # user instead of crashing inside Image.open.
    if not image:
        raise gr.Error("Please upload an image of the pigeon first.")
    description = generate_description(image)
    return generate_image_by_description(description, avatar_style)
# UI definition.
# NOTE(review): the nesting below is reconstructed from the order of the
# context managers; confirm the intended row/column layout.
with gr.Blocks(js=dom.generate_title) as demo:
    # Header row.
    with gr.Row():
        gr.Markdown(value=dom.generate_markdown)

    # Main row: inputs in the first column, results in the second.
    with gr.Row():
        with gr.Column(scale=2, min_width=300):
            selected_image = gr.Image(
                type="filepath",
                label="Upload an Image of the Pigeon",
                height=300,
            )
            example_image = gr.Examples(
                examples=["./examples/pigeon.webp"],
                label="Example Images",
                inputs=[selected_image],
            )
            avatar_style = gr.Radio(
                choices=["Realistic", "Pixel Art", "Imaginative", "Cartoon"],
                label="(optional) Select the avatar style:",
            )
            generate_button = gr.Button(value="Generate Avatar", variant="primary")
        with gr.Column(scale=2, min_width=300):
            generated_image = gr.Gallery(
                type="pil",
                label="Generated Avatar",
                height=300,
            )

    # Clicking the button runs the describe-then-generate pipeline and shows
    # the resulting images in the gallery.
    generate_button.click(
        fn=process_and_generate,
        inputs=[selected_image, avatar_style],
        outputs=generated_image,
    )

demo.launch()