Spaces:
Running
Running
import gradio as gr | |
import torch | |
from PIL import Image | |
import os | |
from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration | |
from huggingface_hub import hf_hub_download | |
import torch.nn as nn | |
class SpriteGenerator(nn.Module): | |
def __init__(self, text_encoder_name="t5-base", latent_dim=512): | |
super(SpriteGenerator, self).__init__() | |
# Text encoder (T5 with lm_head) | |
self.text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder_name) | |
for param in self.text_encoder.parameters(): | |
param.requires_grad = False | |
# Proiezione dal testo al latent space | |
self.text_projection = nn.Sequential( | |
nn.Linear(768, latent_dim), | |
nn.LeakyReLU(0.2), | |
nn.Linear(latent_dim, latent_dim) | |
) | |
# Generator | |
self.generator = nn.Sequential( | |
# Input: latent_dim x 1 x 1 -> 512 x 4 x 4 | |
nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False), | |
nn.BatchNorm2d(512), | |
nn.ReLU(True), | |
# 512 x 4 x 4 -> 256 x 8 x 8 | |
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), | |
nn.BatchNorm2d(256), | |
nn.ReLU(True), | |
# 256 x 8 x 8 -> 128 x 16 x 16 | |
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), | |
nn.BatchNorm2d(128), | |
nn.ReLU(True), | |
# 128 x 16 x 16 -> 64 x 32 x 32 | |
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), | |
nn.BatchNorm2d(64), | |
nn.ReLU(True), | |
# 64 x 32 x 32 -> 32 x 64 x 64 | |
nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False), | |
nn.BatchNorm2d(32), | |
nn.ReLU(True), | |
# 32 x 64 x 64 -> 16 x 128 x 128 | |
nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False), | |
nn.BatchNorm2d(16), | |
nn.ReLU(True), | |
# 16 x 128 x 128 -> 3 x 256 x 256 | |
nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False), | |
) | |
# Frame interpolator | |
self.frame_interpolator = nn.Sequential( | |
nn.Linear(latent_dim + 1, latent_dim), | |
nn.LeakyReLU(0.2), | |
nn.Linear(latent_dim, latent_dim), | |
nn.LeakyReLU(0.2) | |
) | |
def forward(self, input_ids, attention_mask, num_frames=1): | |
batch_size = input_ids.shape[0] | |
# Encode text usando il T5 completo | |
text_outputs = self.text_encoder.encoder( | |
input_ids=input_ids, | |
attention_mask=attention_mask, | |
return_dict=True | |
) | |
# Get text features | |
text_features = text_outputs.last_hidden_state.mean(dim=1) | |
# Project to latent space | |
latent_vector = self.text_projection(text_features) | |
# Generate multiple frames if needed | |
all_frames = [] | |
for frame_idx in range(max(num_frames.max().item(), 1)): | |
frame_info = torch.ones((batch_size, 1), device=latent_vector.device) * frame_idx / max(num_frames.max().item(), 1) | |
# Combine latent vector with frame info | |
frame_latent = self.frame_interpolator( | |
torch.cat([latent_vector, frame_info], dim=1) | |
) | |
# Generate frame | |
frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3) | |
frame = self.generator(frame_latent_reshaped) | |
frame = torch.tanh(frame) | |
all_frames.append(frame) | |
# Stack all frames | |
sprites = torch.stack(all_frames, dim=1) | |
return sprites | |
def initialize_model(): | |
print("Inizializzazione del modello...") | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
model = SpriteGenerator() | |
try: | |
# Scarica il modello da Hugging Face Hub | |
model_path = hf_hub_download( | |
repo_id="Lod34/Animator2D-v2", | |
filename="pytorch_model.bin", | |
repo_type="model" | |
) | |
# Carica il modello | |
state_dict = torch.load(model_path, map_location=device) | |
model.load_state_dict(state_dict) | |
model = model.to(device) | |
model.eval() | |
print(f"Modello caricato con successo su {device}!") | |
return model, device | |
except Exception as e: | |
print(f"Errore nel caricamento del modello: {str(e)}") | |
raise | |
def generate_sprite(prompt, num_frames=8): | |
try: | |
# Usa il modello e il device globali | |
global model, device, tokenizer | |
# Tokenizza il testo | |
tokens = tokenizer(prompt, return_tensors="pt", padding=True) | |
tokens = {k: v.to(device) for k, v in tokens.items()} | |
# Genera l'immagine | |
with torch.no_grad(): | |
frames = model( | |
input_ids=tokens["input_ids"], | |
attention_mask=tokens["attention_mask"], | |
num_frames=torch.tensor([num_frames], device=device) | |
) | |
# Converte il tensore in immagine | |
frames = (frames * 0.5 + 0.5).clamp(0, 1) | |
frames = frames.cpu().numpy() | |
# Ritorna il primo frame come esempio | |
frame = frames[0, 0] # Prende il primo frame del batch | |
frame = (frame * 255).astype('uint8').transpose(1, 2, 0) | |
return Image.fromarray(frame) | |
except Exception as e: | |
print(f"Errore nella generazione: {str(e)}") | |
raise | |
# Inizializzazione globale | |
print("Caricamento del modello e configurazione dell'interfaccia...") | |
try: | |
# Inizializzazione del modello e del tokenizer | |
model, device = initialize_model() | |
tokenizer = AutoTokenizer.from_pretrained("t5-base") | |
# Configurazione dell'interfaccia Gradio | |
interface = gr.Interface( | |
fn=generate_sprite, | |
inputs=[ | |
gr.Textbox( | |
label="Descrivi lo sprite che vuoi generare", | |
placeholder="Esempio: un personaggio pixel art che cammina" | |
), | |
gr.Slider( | |
minimum=1, | |
maximum=16, | |
value=8, | |
step=1, | |
label="Numero di frame", | |
info="Più frame = animazione più fluida ma generazione più lenta" | |
) | |
], | |
outputs=gr.Image(label="Sprite generato"), | |
title="🎮 Animator2D-v2 Sprite Generator", | |
description=""" | |
## Generatore di Sprite Animati | |
Questo strumento genera sprite pixel art da descrizioni testuali. | |
### Come usare: | |
1. Inserisci una descrizione dello sprite che vuoi generare | |
2. Regola il numero di frame desiderati | |
3. Clicca su Submit e attendi la generazione | |
### Tips: | |
- Sii specifico nella descrizione | |
- Prova diversi numeri di frame per risultati diversi | |
- Le descrizioni in inglese potrebbero funzionare meglio | |
""", | |
article=""" | |
### Note: | |
- La generazione può richiedere alcuni secondi | |
- Vengono mostrati solo i primi frame dell'animazione | |
- Per risultati migliori, usa descrizioni dettagliate | |
Creato da [Lod34](https://huggingface.co/Lod34) | |
""" | |
) | |
# Avvio dell'interfaccia | |
interface.launch() | |
except Exception as e: | |
print(f"Errore nell'inizializzazione dell'applicazione: {str(e)}") | |
raise |