import os import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, Dataset, random_split from transformers import AutoTokenizer, AutoModelForSeq2SeqLM from datasets import load_dataset from PIL import Image import numpy as np from torchvision import transforms import matplotlib.pyplot as plt from tqdm import tqdm import io # Definiamo un percorso per salvare il modello addestrato MODEL_PATH = "sprite_generator_model" os.makedirs(MODEL_PATH, exist_ok=True) # Carichiamo il dataset da Hugging Face print("Caricamento del dataset...") dataset = load_dataset("pawkanarek/spraix_1024") print(f"Dataset caricato. Dimensioni: {len(dataset['train'])} esempi di training") # Verifichiamo gli split disponibili print("Split disponibili nel dataset:") print(dataset.keys()) # Stampiamo un esempio per capire la struttura del dataset print("Esempio di dato dal dataset:") example = dataset['train'][0] print("Chiavi disponibili:", example.keys()) for key in example: print(f"{key}: {type(example[key])}") # Se il valore è un dizionario, stampiamo anche le sue chiavi if isinstance(example[key], dict): print(f" Sottochavi: {example[key].keys()}") # Classe per il nostro dataset personalizzato class SpriteDataset(Dataset): def __init__(self, dataset_to_use, max_length=128): self.dataset = dataset_to_use self.tokenizer = AutoTokenizer.from_pretrained("t5-base") self.max_length = max_length self.transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), transforms.ConvertImageDtype(torch.float), # Converti in float32 transforms.Lambda(lambda image: image[:3, :, :]), # Seleziona solo i primi 3 canali (RGB) transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) def __len__(self): return len(self.dataset) def __getitem__(self, idx): item = self.dataset[idx] # Estrai informazioni dalla descrizione completa description = item['text'] if 'text' in item else "" # Estrai numero di frame dal testo num_frames = 1 # valore di default if "frame" in description: # Cerca numeri seguiti da "frame" nel testo import re frames_match = re.search(r'(\d+)-frame', description) if frames_match: num_frames = int(frames_match.group(1)) # Prepara il testo per il modello text_input = f""" Description: {description} Number of frames: {num_frames} """ # Tokenizziamo l'input testuale encoded_text = self.tokenizer( text_input, padding="max_length", max_length=self.max_length, truncation=True, return_tensors="pt" ) # Prepariamo l'immagine (o le immagini se ci sono frame multipli) sprite_frames = [] # Controlla le chiavi disponibili per i frame if 'image' in item: # Se c'è un'unica immagine img = item['image'] if isinstance(img, dict) and 'bytes' in img: img_pil = Image.open(io.BytesIO(img['bytes'])) sprite_frames.append(self.transform(img_pil)) elif hasattr(img, 'convert'): # Se è già un'immagine PIL sprite_frames.append(self.transform(img)) else: # Prova a cercare frame_0, frame_1, ecc. for frame in range(num_frames): frame_key = f'frame_{frame}' if frame_key in item: img = item[frame_key] if isinstance(img, dict) and 'bytes' in img: img_pil = Image.open(io.BytesIO(img['bytes'])) sprite_frames.append(self.transform(img_pil)) elif hasattr(img, 'convert'): # Se è già un'immagine PIL sprite_frames.append(self.transform(img)) # Se non abbiamo trovato immagini, prova a cercare altre chiavi comuni if not sprite_frames: possible_image_keys = ['image', 'img', 'sprite', 'frames'] for key in possible_image_keys: if key in item and item[key] is not None: img = item[key] if isinstance(img, dict) and 'bytes' in img: img_pil = Image.open(io.BytesIO(img['bytes'])) sprite_frames.append(self.transform(img_pil)) elif hasattr(img, 'convert'): # Se è già un'immagine PIL sprite_frames.append(self.transform(img)) break # Se ancora non abbiamo frame, crea un tensore vuoto if not sprite_frames: sprite_frames.append(torch.zeros((3, 256, 256))) # Combiniamo tutti i frame in un unico tensore sprite_tensor = torch.stack(sprite_frames) return { "input_ids": encoded_text.input_ids.squeeze(), "attention_mask": encoded_text.attention_mask.squeeze(), "sprite_frames": sprite_tensor, "num_frames": torch.tensor(num_frames, dtype=torch.int64) } # Modello generatore di sprite class SpriteGenerator(nn.Module): def __init__(self, text_encoder_name="t5-base", latent_dim=512): super(SpriteGenerator, self).__init__() # Encoder testuale self.text_encoder = AutoModelForSeq2SeqLM.from_pretrained(text_encoder_name) # Freeziamo i parametri dell'encoder per iniziare for param in self.text_encoder.parameters(): param.requires_grad = False # Proiezione dal testo al latent space self.text_projection = nn.Sequential( nn.Linear(self.text_encoder.config.d_model, latent_dim), nn.LeakyReLU(0.2), nn.Linear(latent_dim, latent_dim) ) # Frame generator (una rete deconvoluzionale) self.generator = nn.Sequential( # Input: latent_dim x 1 x 1 nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False), # -> 512 x 4 x 4 nn.BatchNorm2d(512), nn.ReLU(True), nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), # -> 256 x 8 x 8 nn.BatchNorm2d(256), nn.ReLU(True), nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), # -> 128 x 16 x 16 nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), # -> 64 x 32 x 32 nn.BatchNorm2d(64), nn.ReLU(True), nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False), # -> 32 x 64 x 64 nn.BatchNorm2d(32), nn.ReLU(True), nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False), # -> 16 x 128 x 128 nn.BatchNorm2d(16), nn.ReLU(True), nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False), # -> 3 x 256 x 256 nn.Tanh() ) # Frame interpolator per supportare animazioni con più frame self.frame_interpolator = nn.Sequential( nn.Linear(latent_dim + 1, latent_dim), # +1 per l'informazione sul frame nn.LeakyReLU(0.2), nn.Linear(latent_dim, latent_dim), nn.LeakyReLU(0.2) ) def forward(self, input_ids, attention_mask, num_frames=1): batch_size = input_ids.shape[0] # Codifichiamo il testo text_outputs = self.text_encoder.encoder( input_ids=input_ids, attention_mask=attention_mask, return_dict=True ) # Utilizziamo l'ultimo hidden state text_features = text_outputs.last_hidden_state.mean(dim=1) # Media per ottenere un vettore per esempio # Proiettiamo nello spazio latente latent_vector = self.text_projection(text_features) # Generiamo frame multipli se necessario all_frames = [] for frame_idx in range(max(num_frames.max().item(), 1)): # Normalizziamo l'indice del frame frame_info = torch.ones((batch_size, 1), device=latent_vector.device) * frame_idx / max(num_frames.max().item(), 1) # Combiniamo il vettore latente con l'informazione sul frame frame_latent = self.frame_interpolator( torch.cat([latent_vector, frame_info], dim=1) ) # Ricordiamo quanti frame generare per ogni esempio del batch frame_mask = (frame_idx < num_frames).float().unsqueeze(1) # Riformattiamo per il generatore frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3) # [B, latent_dim, 1, 1] # Generiamo il frame frame = self.generator(frame_latent_reshaped) * frame_mask.unsqueeze(2).unsqueeze(3) all_frames.append(frame) # Combiniamo tutti i frame sprites = torch.stack(all_frames, dim=1) # [B, num_frames, 3, 256, 256] return sprites # Funzione per addestrare il modello def train_model(model, train_loader, val_loader, epochs=10, lr=0.0002): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Utilizzo del dispositivo: {device}") model = model.to(device) optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999)) criterion = nn.MSELoss() best_val_loss = float('inf') for epoch in range(epochs): # Training model.train() train_loss = 0.0 for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training"): input_ids = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) target_sprites = batch["sprite_frames"].to(device) num_frames = batch["num_frames"].to(device) optimizer.zero_grad() # Forward pass output_sprites = model(input_ids, attention_mask, num_frames) # Calcoliamo la loss per il batch loss = 0.0 for i in range(len(num_frames)): # Utilizziamo solo i frame validi per ogni esempio valid_frames = min(output_sprites.shape[1], target_sprites.shape[1], num_frames[i].item()) if valid_frames > 0: loss += criterion( output_sprites[i, :valid_frames], target_sprites[i, :valid_frames] ) loss = loss / len(num_frames) # Media per batch # Backward pass loss.backward() optimizer.step() train_loss += loss.item() train_loss /= len(train_loader) # Validation model.eval() val_loss = 0.0 with torch.no_grad(): for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} - Validation"): input_ids = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) target_sprites = batch["sprite_frames"].to(device) num_frames = batch["num_frames"].to(device) output_sprites = model(input_ids, attention_mask, num_frames) # Calcoliamo la loss per il batch di validazione loss = 0.0 for i in range(len(num_frames)): valid_frames = min(output_sprites.shape[1], target_sprites.shape[1], num_frames[i].item()) if valid_frames > 0: loss += criterion( output_sprites[i, :valid_frames], target_sprites[i, :valid_frames] ) loss = loss / len(num_frames) val_loss += loss.item() val_loss /= len(val_loader) print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}") # Salviamo il modello migliore if val_loss < best_val_loss: best_val_loss = val_loss torch.save(model.state_dict(), os.path.join(MODEL_PATH, "best_model.pth")) print(f"Modello salvato con Val Loss: {val_loss:.4f}") # Salviamo il modello finale torch.save(model.state_dict(), os.path.join(MODEL_PATH, "Animator2D-v2.pth")) print(f"Addestramento completato. Modello finale salvato.") return model # Codice per l'esecuzione dell'addestramento if __name__ == "__main__": # Dividiamo il dataset in train e validation manualmente # dato che abbiamo solo lo split "train" train_size = int(0.8 * len(dataset['train'])) # 80% per training val_size = len(dataset['train']) - train_size # 20% per validation print(f"Dividendo il dataset: {train_size} esempi per training, {val_size} esempi per validation") # Creiamo i subset train_subset, val_subset = random_split( dataset['train'], [train_size, val_size] ) # Creiamo i dataset personalizzati train_dataset = SpriteDataset(train_subset) val_dataset = SpriteDataset(val_subset) print(f"Dataset creati: {len(train_dataset)} esempi di training, {len(val_dataset)} esempi di validation") # Creiamo i dataloader train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4) val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4) # Creiamo e addestriamo il modello model = SpriteGenerator() trained_model = train_model( model, train_loader, val_loader, epochs=20 ) print("Modello addestrato con successo!")