# Animator2D-v1 / gradio-interface.py
# Author: Lod_34 — commit 98ebcac ("Add files via upload", unverified), 1.61 kB
import torch
import gradio as gr
from transformers import AutoTokenizer
from torchvision import transforms
from PIL import Image
# Module-level setup shared by generate_sprite: choose a device, build the
# Animator2D model with its trained weights, and load the tokenizer used to
# encode text prompts.
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): Animator2D is not imported by any of this file's imports, so
# this line raises NameError as written. Import the class from the module that
# defines it (presumably the training script) — TODO confirm the module name.
model = Animator2D().to(device)
# map_location remaps saved tensors onto `device`, so a checkpoint saved on a
# GPU can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust
# (consider weights_only=True on PyTorch versions that support it).
model.load_state_dict(torch.load("animator2D-model.pth", map_location=device))
# Switch to evaluation mode (affects layers such as dropout/batch-norm, if the
# model contains any).
model.eval()
# Presumably the same tokenizer that was used during training — verify, since a
# mismatch would silently produce meaningless token ids.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def generate_sprite(num_frames, description, action, direction):
    """Generate a sprite image from a textual description.

    The four inputs are combined into a single English prompt, tokenized with
    the module-level ``tokenizer`` and passed to the module-level ``model``.

    Args:
        num_frames: Frame count to mention in the prompt. ``gr.Number``
            delivers a float by default (e.g. ``17.0``), so it is converted to
            ``int`` to keep the prompt as ``"17-frame ..."``. ``None`` (empty
            field) is passed through unchanged.
        description: Free-text description of the sprite.
        action: Action label inserted into the prompt.
        direction: Facing-direction label inserted into the prompt.

    Returns:
        A ``PIL.Image.Image`` built from the model output.
    """
    if num_frames is not None:
        # Avoid "17.0-frame ..." prompts when Gradio hands us a float.
        num_frames = int(num_frames)
    text = f"{num_frames}-frame sprite animation of: {description}, that: {action}, facing: {direction}"
    # Pad/truncate to a fixed 128 tokens; the attention mask marks real tokens.
    encoded_text = tokenizer(
        text, padding="max_length", max_length=128, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        text_ids = encoded_text["input_ids"].to(device)
        text_mask = encoded_text["attention_mask"].to(device)
        # Drop the leading batch dimension of size 1 before image conversion.
        generated_sprite = model(text_ids, text_mask).cpu().squeeze(0)
    # Denormalize from [-1, 1] to [0, 1] (assumes the model was trained on
    # [-1, 1]-normalized images — TODO confirm). Clamp because ToPILImage
    # multiplies float tensors by 255 and casts to uint8, so out-of-range
    # values would otherwise wrap around into wrong pixel colors.
    generated_sprite = ((generated_sprite + 1) / 2).clamp(0, 1)
    return transforms.ToPILImage()(generated_sprite)
# Web UI: the four inputs are passed positionally to generate_sprite
# (num_frames, description, action, direction); the returned PIL image is shown.
# Labels are Italian: "Numero di Frame" = number of frames, "Descrizione dello
# Sprite" = sprite description, "Azione" = action (walks/runs/jumps/attacks),
# "Direzione" = direction (North/South/East/West).
# NOTE(review): the Italian action/direction words are inserted verbatim into
# the prompt — confirm the model was trained on the same vocabulary.
iface = gr.Interface(
    fn=generate_sprite,
    inputs=[
        # precision=0 makes Gradio round the value and deliver an int, so the
        # prompt reads "17-frame ..." rather than "17.0-frame ...".
        gr.Number(label="Numero di Frame", value=17, precision=0),
        gr.Textbox(label="Descrizione dello Sprite"),
        gr.Dropdown(["cammina", "corre", "salta", "attacca"], label="Azione"),
        gr.Dropdown(["Nord", "Sud", "Est", "Ovest"], label="Direzione"),
    ],
    outputs=gr.Image(type="pil"),
    title="Animator2D Generator",
    description="Genera animazioni di sprite basate su descrizioni testuali.",
)
# Start the web server only when this file is executed as a script, so that
# importing it (e.g. for testing) does not block on a running UI.
# share=False disables Gradio's public share link; the UI stays local.
if __name__ == "__main__":
    iface.launch(share=False)