Animator2D-v2 / app.py
Lod34's picture
Update app.py
57ee356 verified
import gradio as gr
import torch
from PIL import Image
import os
from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration
from huggingface_hub import hf_hub_download
import torch.nn as nn
class SpriteGenerator(nn.Module):
def __init__(self, text_encoder_name="t5-base", latent_dim=512):
super(SpriteGenerator, self).__init__()
# Text encoder (T5 with lm_head)
self.text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder_name)
for param in self.text_encoder.parameters():
param.requires_grad = False
# Proiezione dal testo al latent space
self.text_projection = nn.Sequential(
nn.Linear(768, latent_dim),
nn.LeakyReLU(0.2),
nn.Linear(latent_dim, latent_dim)
)
# Generator
self.generator = nn.Sequential(
# Input: latent_dim x 1 x 1 -> 512 x 4 x 4
nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
# 512 x 4 x 4 -> 256 x 8 x 8
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
# 256 x 8 x 8 -> 128 x 16 x 16
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
# 128 x 16 x 16 -> 64 x 32 x 32
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
# 64 x 32 x 32 -> 32 x 64 x 64
nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(True),
# 32 x 64 x 64 -> 16 x 128 x 128
nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False),
nn.BatchNorm2d(16),
nn.ReLU(True),
# 16 x 128 x 128 -> 3 x 256 x 256
nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False),
)
# Frame interpolator
self.frame_interpolator = nn.Sequential(
nn.Linear(latent_dim + 1, latent_dim),
nn.LeakyReLU(0.2),
nn.Linear(latent_dim, latent_dim),
nn.LeakyReLU(0.2)
)
def forward(self, input_ids, attention_mask, num_frames=1):
batch_size = input_ids.shape[0]
# Encode text usando il T5 completo
text_outputs = self.text_encoder.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
)
# Get text features
text_features = text_outputs.last_hidden_state.mean(dim=1)
# Project to latent space
latent_vector = self.text_projection(text_features)
# Generate multiple frames if needed
all_frames = []
for frame_idx in range(max(num_frames.max().item(), 1)):
frame_info = torch.ones((batch_size, 1), device=latent_vector.device) * frame_idx / max(num_frames.max().item(), 1)
# Combine latent vector with frame info
frame_latent = self.frame_interpolator(
torch.cat([latent_vector, frame_info], dim=1)
)
# Generate frame
frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3)
frame = self.generator(frame_latent_reshaped)
frame = torch.tanh(frame)
all_frames.append(frame)
# Stack all frames
sprites = torch.stack(all_frames, dim=1)
return sprites
def initialize_model():
print("Inizializzazione del modello...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SpriteGenerator()
try:
# Scarica il modello da Hugging Face Hub
model_path = hf_hub_download(
repo_id="Lod34/Animator2D-v2",
filename="pytorch_model.bin",
repo_type="model"
)
# Carica il modello
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
print(f"Modello caricato con successo su {device}!")
return model, device
except Exception as e:
print(f"Errore nel caricamento del modello: {str(e)}")
raise
def generate_sprite(prompt, num_frames=8):
try:
# Usa il modello e il device globali
global model, device, tokenizer
# Tokenizza il testo
tokens = tokenizer(prompt, return_tensors="pt", padding=True)
tokens = {k: v.to(device) for k, v in tokens.items()}
# Genera l'immagine
with torch.no_grad():
frames = model(
input_ids=tokens["input_ids"],
attention_mask=tokens["attention_mask"],
num_frames=torch.tensor([num_frames], device=device)
)
# Converte il tensore in immagine
frames = (frames * 0.5 + 0.5).clamp(0, 1)
frames = frames.cpu().numpy()
# Ritorna il primo frame come esempio
frame = frames[0, 0] # Prende il primo frame del batch
frame = (frame * 255).astype('uint8').transpose(1, 2, 0)
return Image.fromarray(frame)
except Exception as e:
print(f"Errore nella generazione: {str(e)}")
raise
# Inizializzazione globale
print("Caricamento del modello e configurazione dell'interfaccia...")
try:
# Inizializzazione del modello e del tokenizer
model, device = initialize_model()
tokenizer = AutoTokenizer.from_pretrained("t5-base")
# Configurazione dell'interfaccia Gradio
interface = gr.Interface(
fn=generate_sprite,
inputs=[
gr.Textbox(
label="Descrivi lo sprite che vuoi generare",
placeholder="Esempio: un personaggio pixel art che cammina"
),
gr.Slider(
minimum=1,
maximum=16,
value=8,
step=1,
label="Numero di frame",
info="Più frame = animazione più fluida ma generazione più lenta"
)
],
outputs=gr.Image(label="Sprite generato"),
title="🎮 Animator2D-v2 Sprite Generator",
description="""
## Generatore di Sprite Animati
Questo strumento genera sprite pixel art da descrizioni testuali.
### Come usare:
1. Inserisci una descrizione dello sprite che vuoi generare
2. Regola il numero di frame desiderati
3. Clicca su Submit e attendi la generazione
### Tips:
- Sii specifico nella descrizione
- Prova diversi numeri di frame per risultati diversi
- Le descrizioni in inglese potrebbero funzionare meglio
""",
article="""
### Note:
- La generazione può richiedere alcuni secondi
- Vengono mostrati solo i primi frame dell'animazione
- Per risultati migliori, usa descrizioni dettagliate
Creato da [Lod34](https://huggingface.co/Lod34)
"""
)
# Avvio dell'interfaccia
interface.launch()
except Exception as e:
print(f"Errore nell'inizializzazione dell'applicazione: {str(e)}")
raise