Spaces:
No application file
No application file
import torch | |
import gradio as gr | |
from transformers import AutoTokenizer | |
from torchvision import transforms | |
from PIL import Image | |
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the network, move it to the selected device, and restore the
# weights stored in "animator2D-model.pth" (tensors are mapped onto `device`
# while loading).
# NOTE(review): Animator2D is not defined or imported in the visible part of
# this file -- confirm where it is supposed to come from.
model = Animator2D()
model = model.to(device)
model.load_state_dict(
    torch.load("animator2D-model.pth", map_location=device)
)
model.eval()  # switch the module to evaluation mode before serving requests

# Tokenizer used to turn the text prompt into model inputs.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def generate_sprite(num_frames, description, action, direction):
    """Generate a sprite image from a text prompt assembled from the UI inputs.

    The four arguments are interpolated into one prompt string, tokenized with
    the module-level ``tokenizer`` (padded/truncated to 128 tokens) and fed to
    the module-level ``model`` with gradient tracking disabled.

    Args:
        num_frames: Frame count mentioned in the prompt. An integral float
            (e.g. ``17.0``, which is how Gradio's ``Number`` component can
            deliver the value) is formatted as an integer (``17``).
        description: Free-text description of the sprite.
        action: Action the sprite performs, inserted verbatim into the prompt.
        direction: Direction the sprite faces, inserted verbatim into the prompt.

    Returns:
        The PIL image produced by ``transforms.ToPILImage`` from the model
        output after rescaling and clamping to ``[0, 1]``.
    """
    # Avoid prompts such as "17.0-frame sprite animation": an integral float is
    # rendered as a plain integer. Any other value is left untouched.
    if isinstance(num_frames, float) and num_frames.is_integer():
        num_frames = int(num_frames)

    text = f"{num_frames}-frame sprite animation of: {description}, that: {action}, facing: {direction}"
    encoded_text = tokenizer(
        text, padding="max_length", max_length=128, truncation=True, return_tensors="pt"
    )

    with torch.no_grad():
        text_ids = encoded_text['input_ids'].to(device)
        text_mask = encoded_text['attention_mask'].to(device)
        # Drop the leading batch dimension (a single prompt is sent per call).
        generated_sprite = model(text_ids, text_mask).cpu().squeeze(0)

    # Denormalization: assumes the model emits values in [-1, 1] -- TODO confirm.
    generated_sprite = (generated_sprite + 1) / 2
    # ToPILImage multiplies float tensors by 255 and casts to uint8, so values
    # slightly outside [0, 1] would wrap around; clamp them first.
    generated_sprite = generated_sprite.clamp(0.0, 1.0)
    return transforms.ToPILImage()(generated_sprite)
# Gradio UI: four inputs (frame count, sprite description, action, direction)
# are passed positionally to generate_sprite; its PIL result is shown as an image.
iface = gr.Interface(
    fn=generate_sprite,
    inputs=[
        # precision=0 makes Gradio round the value and hand it to the callback
        # as an int, so the frame count is never delivered as e.g. 17.0.
        gr.Number(label="Numero di Frame", value=17, precision=0),
        gr.Textbox(label="Descrizione dello Sprite"),
        gr.Dropdown(["cammina", "corre", "salta", "attacca"], label="Azione"),
        gr.Dropdown(["Nord", "Sud", "Est", "Ovest"], label="Direzione"),
    ],
    outputs=gr.Image(type="pil"),
    title="Animator2D Generator",
    description="Genera animazioni di sprite basate su descrizioni testuali.",
)

iface.launch(share=False)  # Disable public sharing