import torch
import gradio as gr
from transformers import AutoTokenizer
from torchvision import transforms
from PIL import Image

# NOTE(review): `Animator2D` is used below but is never imported in this file,
# so the script fails with NameError at startup. Import it from the module that
# defines the model architecture (TODO: confirm the module name).

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Animator2D().to(device)
# The checkpoint is loaded as a state_dict (tensors only), so restrict
# unpickling to weights; this prevents arbitrary code in the file from running.
model.load_state_dict(
    torch.load("animator2D-model.pth", map_location=device, weights_only=True)
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Stateless transform; build it once instead of on every request.
_to_pil = transforms.ToPILImage()


def generate_sprite(num_frames, description, action, direction):
    """Generate a sprite image from the UI fields via the Animator2D model.

    The fields are combined into one English text prompt, tokenized with the
    BERT tokenizer (padded/truncated to 128 tokens), and fed to the model.

    Args:
        num_frames: Value of the Gradio ``Number`` field. Gradio delivers it as
            a float by default (e.g. ``17.0``), so it is rounded to an ``int``
            before being written into the prompt.
        description: Free-text description of the sprite.
        action: Action label chosen in the dropdown.
        direction: Facing direction chosen in the dropdown.

    Returns:
        A ``PIL.Image.Image`` built from the model output.

    Raises:
        gr.Error: If the frame-count field is empty.
    """
    if num_frames is None:
        raise gr.Error("Inserisci il numero di frame.")
    # Without this, the prompt would read "17.0-frame ..." for the default value.
    frames = int(round(num_frames))

    text = f"{frames}-frame sprite animation of: {description}, that: {action}, facing: {direction}"
    encoded_text = tokenizer(
        text,
        padding="max_length",
        max_length=128,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        text_ids = encoded_text["input_ids"].to(device)
        text_mask = encoded_text["attention_mask"].to(device)
        # Drop the batch dimension; ToPILImage expects a (C, H, W) tensor.
        # NOTE(review): assumes the model returns (1, C, H, W) — confirm.
        generated_sprite = model(text_ids, text_mask).cpu().squeeze(0)

    # Undo the normalization (model output presumably lies in [-1, 1]). Clamp so
    # out-of-range values don't wrap around when ToPILImage scales by 255 and
    # casts to uint8.
    generated_sprite = ((generated_sprite + 1) / 2).clamp(0, 1)
    return _to_pil(generated_sprite)


iface = gr.Interface(
    fn=generate_sprite,
    inputs=[
        gr.Number(label="Numero di Frame", value=17),
        gr.Textbox(label="Descrizione dello Sprite"),
        # Defaults keep an untouched dropdown from injecting "None" into the prompt.
        gr.Dropdown(
            ["cammina", "corre", "salta", "attacca"], label="Azione", value="cammina"
        ),
        gr.Dropdown(["Nord", "Sud", "Est", "Ovest"], label="Direzione", value="Nord"),
    ],
    outputs=gr.Image(type="pil"),
    title="Animator2D Generator",
    description="Genera animazioni di sprite basate su descrizioni testuali.",
)

iface.launch(share=False)  # Keep the app local: no public share link.