import gradio as gr import torch from PIL import Image import os from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration from huggingface_hub import hf_hub_download import torch.nn as nn class SpriteGenerator(nn.Module): def __init__(self, text_encoder_name="t5-base", latent_dim=512): super(SpriteGenerator, self).__init__() # Text encoder (T5 with lm_head) self.text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder_name) for param in self.text_encoder.parameters(): param.requires_grad = False # Proiezione dal testo al latent space self.text_projection = nn.Sequential( nn.Linear(768, latent_dim), nn.LeakyReLU(0.2), nn.Linear(latent_dim, latent_dim) ) # Generator self.generator = nn.Sequential( # Input: latent_dim x 1 x 1 -> 512 x 4 x 4 nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), # 512 x 4 x 4 -> 256 x 8 x 8 nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), # 256 x 8 x 8 -> 128 x 16 x 16 nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), # 128 x 16 x 16 -> 64 x 32 x 32 nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True), # 64 x 32 x 32 -> 32 x 64 x 64 nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False), nn.BatchNorm2d(32), nn.ReLU(True), # 32 x 64 x 64 -> 16 x 128 x 128 nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False), nn.BatchNorm2d(16), nn.ReLU(True), # 16 x 128 x 128 -> 3 x 256 x 256 nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False), ) # Frame interpolator self.frame_interpolator = nn.Sequential( nn.Linear(latent_dim + 1, latent_dim), nn.LeakyReLU(0.2), nn.Linear(latent_dim, latent_dim), nn.LeakyReLU(0.2) ) def forward(self, input_ids, attention_mask, num_frames=1): batch_size = input_ids.shape[0] # Encode text usando il T5 completo text_outputs = self.text_encoder.encoder( input_ids=input_ids, attention_mask=attention_mask, return_dict=True ) # Get text features text_features = text_outputs.last_hidden_state.mean(dim=1) # Project to latent space latent_vector = self.text_projection(text_features) # Generate multiple frames if needed all_frames = [] for frame_idx in range(max(num_frames.max().item(), 1)): frame_info = torch.ones((batch_size, 1), device=latent_vector.device) * frame_idx / max(num_frames.max().item(), 1) # Combine latent vector with frame info frame_latent = self.frame_interpolator( torch.cat([latent_vector, frame_info], dim=1) ) # Generate frame frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3) frame = self.generator(frame_latent_reshaped) frame = torch.tanh(frame) all_frames.append(frame) # Stack all frames sprites = torch.stack(all_frames, dim=1) return sprites def initialize_model(): print("Inizializzazione del modello...") device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = SpriteGenerator() try: # Scarica il modello da Hugging Face Hub model_path = hf_hub_download( repo_id="Lod34/Animator2D-v2", filename="pytorch_model.bin", repo_type="model" ) # Carica il modello state_dict = torch.load(model_path, map_location=device) model.load_state_dict(state_dict) model = model.to(device) model.eval() print(f"Modello caricato con successo su {device}!") return model, device except Exception as e: print(f"Errore nel caricamento del modello: {str(e)}") raise def generate_sprite(prompt, num_frames=8): try: # Usa il modello e il device globali global model, device, tokenizer # Tokenizza il testo tokens = tokenizer(prompt, return_tensors="pt", padding=True) tokens = {k: v.to(device) for k, v in tokens.items()} # Genera l'immagine with torch.no_grad(): frames = model( input_ids=tokens["input_ids"], attention_mask=tokens["attention_mask"], num_frames=torch.tensor([num_frames], device=device) ) # Converte il tensore in immagine frames = (frames * 0.5 + 0.5).clamp(0, 1) frames = frames.cpu().numpy() # Ritorna il primo frame come esempio frame = frames[0, 0] # Prende il primo frame del batch frame = (frame * 255).astype('uint8').transpose(1, 2, 0) return Image.fromarray(frame) except Exception as e: print(f"Errore nella generazione: {str(e)}") raise # Inizializzazione globale print("Caricamento del modello e configurazione dell'interfaccia...") try: # Inizializzazione del modello e del tokenizer model, device = initialize_model() tokenizer = AutoTokenizer.from_pretrained("t5-base") # Configurazione dell'interfaccia Gradio interface = gr.Interface( fn=generate_sprite, inputs=[ gr.Textbox( label="Descrivi lo sprite che vuoi generare", placeholder="Esempio: un personaggio pixel art che cammina" ), gr.Slider( minimum=1, maximum=16, value=8, step=1, label="Numero di frame", info="Più frame = animazione più fluida ma generazione più lenta" ) ], outputs=gr.Image(label="Sprite generato"), title="🎮 Animator2D-v2 Sprite Generator", description=""" ## Generatore di Sprite Animati Questo strumento genera sprite pixel art da descrizioni testuali. ### Come usare: 1. Inserisci una descrizione dello sprite che vuoi generare 2. Regola il numero di frame desiderati 3. Clicca su Submit e attendi la generazione ### Tips: - Sii specifico nella descrizione - Prova diversi numeri di frame per risultati diversi - Le descrizioni in inglese potrebbero funzionare meglio """, article=""" ### Note: - La generazione può richiedere alcuni secondi - Vengono mostrati solo i primi frame dell'animazione - Per risultati migliori, usa descrizioni dettagliate Creato da [Lod34](https://huggingface.co/Lod34) """ ) # Avvio dell'interfaccia interface.launch() except Exception as e: print(f"Errore nell'inizializzazione dell'applicazione: {str(e)}") raise