Lorenzo Adacher commited on
Commit
76e3656
·
verified ·
1 Parent(s): 9faa6d9

Upload 2 files

Browse files
Files changed (2) hide show
  1. gradio-interface.py +71 -0
  2. training-code.py +360 -0
gradio-interface.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ # Define the function to generate the sprite based on user input
4
+ def generate_sprite(character_description, num_frames, character_action, viewing_direction):
5
+ # Combine user inputs into a single prompt
6
+ prompt = f"Character description: {character_description}\n" \
7
+ f"Character action: {character_action}\n" \
8
+ f"Viewing direction: {viewing_direction}\n" \
9
+ f"Number of frames: {num_frames}"
10
+
11
+ # Load the model from Hugging Face Hub
12
+ model = gr.Interface.load("huggingface/Lod34/Animator2D-v2")
13
+
14
+ # Generate the sprite using the model
15
+ result = model(prompt)
16
+
17
+ return result
18
+
19
+ # Configure the Gradio interface
20
+ with gr.Blocks(title="Animated Sprite Generator") as demo:
21
+ gr.Markdown("# 🎮 AI Animated Sprite Generator")
22
+ gr.Markdown("""
23
+ This tool uses an AI model to generate animated sprites based on text descriptions.
24
+ Enter the character description, number of frames, character action, and viewing direction to generate your animated sprite.
25
+ """)
26
+
27
+ with gr.Row():
28
+ with gr.Column():
29
+ # Input components
30
+ char_desc = gr.Textbox(label="Character Description",
31
+ placeholder="Ex: a knight with golden armor and a fire sword",
32
+ lines=3)
33
+ num_frames = gr.Slider(minimum=1, maximum=8, step=1, value=4,
34
+ label="Number of Animation Frames")
35
+ char_action = gr.Dropdown(
36
+ choices=["idle", "walk", "run", "attack", "jump", "die", "cast spell", "dance"],
37
+ label="Character Action",
38
+ value="idle"
39
+ )
40
+ view_direction = gr.Dropdown(
41
+ choices=["front", "back", "left", "right", "front-left", "front-right", "back-left", "back-right"],
42
+ label="Viewing Direction",
43
+ value="front"
44
+ )
45
+ generate_btn = gr.Button("Generate Animated Sprite")
46
+
47
+ with gr.Column():
48
+ # Output component
49
+ animated_output = gr.Image(label="Animated Sprite (GIF)")
50
+
51
+ # Connect the button to the function
52
+ generate_btn.click(
53
+ fn=generate_sprite,
54
+ inputs=[char_desc, num_frames, char_action, view_direction],
55
+ outputs=animated_output
56
+ )
57
+
58
+ # Predefined examples
59
+ gr.Examples(
60
+ [
61
+ ["A wizard with blue cloak and pointed hat", 4, "cast spell", "front"],
62
+ ["A warrior with heavy armor and axe", 6, "attack", "right"],
63
+ ["A ninja with black clothes and throwing stars", 8, "run", "front-left"],
64
+ ["A princess with golden crown and pink dress", 4, "dance", "front"]
65
+ ],
66
+ inputs=[char_desc, num_frames, char_action, view_direction]
67
+ )
68
+
69
+ # Launch the Gradio interface
70
+ if __name__ == "__main__":
71
+ demo.launch(share=True)
training-code.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.utils.data import DataLoader, Dataset, random_split
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
+ from datasets import load_dataset
8
+ from PIL import Image
9
+ import numpy as np
10
+ from torchvision import transforms
11
+ import matplotlib.pyplot as plt
12
+ from tqdm import tqdm
13
+ import io
14
+
15
+ # Definiamo un percorso per salvare il modello addestrato
16
+ MODEL_PATH = "sprite_generator_model"
17
+ os.makedirs(MODEL_PATH, exist_ok=True)
18
+
19
+ # Carichiamo il dataset da Hugging Face
20
+ print("Caricamento del dataset...")
21
+ dataset = load_dataset("pawkanarek/spraix_1024")
22
+ print(f"Dataset caricato. Dimensioni: {len(dataset['train'])} esempi di training")
23
+
24
+ # Verifichiamo gli split disponibili
25
+ print("Split disponibili nel dataset:")
26
+ print(dataset.keys())
27
+
28
+ # Stampiamo un esempio per capire la struttura del dataset
29
+ print("Esempio di dato dal dataset:")
30
+ example = dataset['train'][0]
31
+ print("Chiavi disponibili:", example.keys())
32
+ for key in example:
33
+ print(f"{key}: {type(example[key])}")
34
+ # Se il valore è un dizionario, stampiamo anche le sue chiavi
35
+ if isinstance(example[key], dict):
36
+ print(f" Sottochavi: {example[key].keys()}")
37
+
38
+ # Classe per il nostro dataset personalizzato
39
+ class SpriteDataset(Dataset):
40
+ def __init__(self, dataset_to_use, max_length=128):
41
+ self.dataset = dataset_to_use
42
+ self.tokenizer = AutoTokenizer.from_pretrained("t5-base")
43
+ self.max_length = max_length
44
+ self.transform = transforms.Compose([
45
+ transforms.Resize((256, 256)),
46
+ transforms.ToTensor(),
47
+ transforms.ConvertImageDtype(torch.float), # Converti in float32
48
+ transforms.Lambda(lambda image: image[:3, :, :]), # Seleziona solo i primi 3 canali (RGB)
49
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
50
+ ])
51
+
52
+ def __len__(self):
53
+ return len(self.dataset)
54
+
55
+ def __getitem__(self, idx):
56
+ item = self.dataset[idx]
57
+
58
+ # Estrai informazioni dalla descrizione completa
59
+ description = item['text'] if 'text' in item else ""
60
+
61
+ # Estrai numero di frame dal testo
62
+ num_frames = 1 # valore di default
63
+ if "frame" in description:
64
+ # Cerca numeri seguiti da "frame" nel testo
65
+ import re
66
+ frames_match = re.search(r'(\d+)-frame', description)
67
+ if frames_match:
68
+ num_frames = int(frames_match.group(1))
69
+
70
+ # Prepara il testo per il modello
71
+ text_input = f"""
72
+ Description: {description}
73
+ Number of frames: {num_frames}
74
+ """
75
+
76
+ # Tokenizziamo l'input testuale
77
+ encoded_text = self.tokenizer(
78
+ text_input,
79
+ padding="max_length",
80
+ max_length=self.max_length,
81
+ truncation=True,
82
+ return_tensors="pt"
83
+ )
84
+
85
+ # Prepariamo l'immagine (o le immagini se ci sono frame multipli)
86
+ sprite_frames = []
87
+
88
+ # Controlla le chiavi disponibili per i frame
89
+ if 'image' in item:
90
+ # Se c'è un'unica immagine
91
+ img = item['image']
92
+ if isinstance(img, dict) and 'bytes' in img:
93
+ img_pil = Image.open(io.BytesIO(img['bytes']))
94
+ sprite_frames.append(self.transform(img_pil))
95
+ elif hasattr(img, 'convert'): # Se è già un'immagine PIL
96
+ sprite_frames.append(self.transform(img))
97
+ else:
98
+ # Prova a cercare frame_0, frame_1, ecc.
99
+ for frame in range(num_frames):
100
+ frame_key = f'frame_{frame}'
101
+ if frame_key in item:
102
+ img = item[frame_key]
103
+ if isinstance(img, dict) and 'bytes' in img:
104
+ img_pil = Image.open(io.BytesIO(img['bytes']))
105
+ sprite_frames.append(self.transform(img_pil))
106
+ elif hasattr(img, 'convert'): # Se è già un'immagine PIL
107
+ sprite_frames.append(self.transform(img))
108
+
109
+ # Se non abbiamo trovato immagini, prova a cercare altre chiavi comuni
110
+ if not sprite_frames:
111
+ possible_image_keys = ['image', 'img', 'sprite', 'frames']
112
+ for key in possible_image_keys:
113
+ if key in item and item[key] is not None:
114
+ img = item[key]
115
+ if isinstance(img, dict) and 'bytes' in img:
116
+ img_pil = Image.open(io.BytesIO(img['bytes']))
117
+ sprite_frames.append(self.transform(img_pil))
118
+ elif hasattr(img, 'convert'): # Se è già un'immagine PIL
119
+ sprite_frames.append(self.transform(img))
120
+ break
121
+
122
+ # Se ancora non abbiamo frame, crea un tensore vuoto
123
+ if not sprite_frames:
124
+ sprite_frames.append(torch.zeros((3, 256, 256)))
125
+
126
+ # Combiniamo tutti i frame in un unico tensore
127
+ sprite_tensor = torch.stack(sprite_frames)
128
+
129
+ return {
130
+ "input_ids": encoded_text.input_ids.squeeze(),
131
+ "attention_mask": encoded_text.attention_mask.squeeze(),
132
+ "sprite_frames": sprite_tensor,
133
+ "num_frames": torch.tensor(num_frames, dtype=torch.int64)
134
+ }
135
+
136
+ # Modello generatore di sprite
137
+ class SpriteGenerator(nn.Module):
138
+ def __init__(self, text_encoder_name="t5-base", latent_dim=512):
139
+ super(SpriteGenerator, self).__init__()
140
+
141
+ # Encoder testuale
142
+ self.text_encoder = AutoModelForSeq2SeqLM.from_pretrained(text_encoder_name)
143
+ # Freeziamo i parametri dell'encoder per iniziare
144
+ for param in self.text_encoder.parameters():
145
+ param.requires_grad = False
146
+
147
+ # Proiezione dal testo al latent space
148
+ self.text_projection = nn.Sequential(
149
+ nn.Linear(self.text_encoder.config.d_model, latent_dim),
150
+ nn.LeakyReLU(0.2),
151
+ nn.Linear(latent_dim, latent_dim)
152
+ )
153
+
154
+ # Frame generator (una rete deconvoluzionale)
155
+ self.generator = nn.Sequential(
156
+ # Input: latent_dim x 1 x 1
157
+ nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False), # -> 512 x 4 x 4
158
+ nn.BatchNorm2d(512),
159
+ nn.ReLU(True),
160
+
161
+ nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), # -> 256 x 8 x 8
162
+ nn.BatchNorm2d(256),
163
+ nn.ReLU(True),
164
+
165
+ nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), # -> 128 x 16 x 16
166
+ nn.BatchNorm2d(128),
167
+ nn.ReLU(True),
168
+
169
+ nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), # -> 64 x 32 x 32
170
+ nn.BatchNorm2d(64),
171
+ nn.ReLU(True),
172
+
173
+ nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False), # -> 32 x 64 x 64
174
+ nn.BatchNorm2d(32),
175
+ nn.ReLU(True),
176
+
177
+ nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False), # -> 16 x 128 x 128
178
+ nn.BatchNorm2d(16),
179
+ nn.ReLU(True),
180
+
181
+ nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False), # -> 3 x 256 x 256
182
+ nn.Tanh()
183
+ )
184
+
185
+ # Frame interpolator per supportare animazioni con più frame
186
+ self.frame_interpolator = nn.Sequential(
187
+ nn.Linear(latent_dim + 1, latent_dim), # +1 per l'informazione sul frame
188
+ nn.LeakyReLU(0.2),
189
+ nn.Linear(latent_dim, latent_dim),
190
+ nn.LeakyReLU(0.2)
191
+ )
192
+
193
+ def forward(self, input_ids, attention_mask, num_frames=1):
194
+ batch_size = input_ids.shape[0]
195
+
196
+ # Codifichiamo il testo
197
+ text_outputs = self.text_encoder.encoder(
198
+ input_ids=input_ids,
199
+ attention_mask=attention_mask,
200
+ return_dict=True
201
+ )
202
+
203
+ # Utilizziamo l'ultimo hidden state
204
+ text_features = text_outputs.last_hidden_state.mean(dim=1) # Media per ottenere un vettore per esempio
205
+
206
+ # Proiettiamo nello spazio latente
207
+ latent_vector = self.text_projection(text_features)
208
+
209
+ # Generiamo frame multipli se necessario
210
+ all_frames = []
211
+ for frame_idx in range(max(num_frames.max().item(), 1)):
212
+ # Normalizziamo l'indice del frame
213
+ frame_info = torch.ones((batch_size, 1), device=latent_vector.device) * frame_idx / max(num_frames.max().item(), 1)
214
+
215
+ # Combiniamo il vettore latente con l'informazione sul frame
216
+ frame_latent = self.frame_interpolator(
217
+ torch.cat([latent_vector, frame_info], dim=1)
218
+ )
219
+
220
+ # Ricordiamo quanti frame generare per ogni esempio del batch
221
+ frame_mask = (frame_idx < num_frames).float().unsqueeze(1)
222
+
223
+ # Riformattiamo per il generatore
224
+ frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3) # [B, latent_dim, 1, 1]
225
+
226
+ # Generiamo il frame
227
+ frame = self.generator(frame_latent_reshaped) * frame_mask.unsqueeze(2).unsqueeze(3)
228
+ all_frames.append(frame)
229
+
230
+ # Combiniamo tutti i frame
231
+ sprites = torch.stack(all_frames, dim=1) # [B, num_frames, 3, 256, 256]
232
+
233
+ return sprites
234
+
235
+ # Funzione per addestrare il modello
236
+ def train_model(model, train_loader, val_loader, epochs=10, lr=0.0002):
237
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
238
+ print(f"Utilizzo del dispositivo: {device}")
239
+
240
+ model = model.to(device)
241
+
242
+ optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))
243
+ criterion = nn.MSELoss()
244
+
245
+ best_val_loss = float('inf')
246
+
247
+ for epoch in range(epochs):
248
+ # Training
249
+ model.train()
250
+ train_loss = 0.0
251
+
252
+ for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training"):
253
+ input_ids = batch["input_ids"].to(device)
254
+ attention_mask = batch["attention_mask"].to(device)
255
+ target_sprites = batch["sprite_frames"].to(device)
256
+ num_frames = batch["num_frames"].to(device)
257
+
258
+ optimizer.zero_grad()
259
+
260
+ # Forward pass
261
+ output_sprites = model(input_ids, attention_mask, num_frames)
262
+
263
+ # Calcoliamo la loss per il batch
264
+ loss = 0.0
265
+ for i in range(len(num_frames)):
266
+ # Utilizziamo solo i frame validi per ogni esempio
267
+ valid_frames = min(output_sprites.shape[1], target_sprites.shape[1], num_frames[i].item())
268
+ if valid_frames > 0:
269
+ loss += criterion(
270
+ output_sprites[i, :valid_frames],
271
+ target_sprites[i, :valid_frames]
272
+ )
273
+
274
+ loss = loss / len(num_frames) # Media per batch
275
+
276
+ # Backward pass
277
+ loss.backward()
278
+ optimizer.step()
279
+
280
+ train_loss += loss.item()
281
+
282
+ train_loss /= len(train_loader)
283
+
284
+ # Validation
285
+ model.eval()
286
+ val_loss = 0.0
287
+
288
+ with torch.no_grad():
289
+ for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} - Validation"):
290
+ input_ids = batch["input_ids"].to(device)
291
+ attention_mask = batch["attention_mask"].to(device)
292
+ target_sprites = batch["sprite_frames"].to(device)
293
+ num_frames = batch["num_frames"].to(device)
294
+
295
+ output_sprites = model(input_ids, attention_mask, num_frames)
296
+
297
+ # Calcoliamo la loss per il batch di validazione
298
+ loss = 0.0
299
+ for i in range(len(num_frames)):
300
+ valid_frames = min(output_sprites.shape[1], target_sprites.shape[1], num_frames[i].item())
301
+ if valid_frames > 0:
302
+ loss += criterion(
303
+ output_sprites[i, :valid_frames],
304
+ target_sprites[i, :valid_frames]
305
+ )
306
+
307
+ loss = loss / len(num_frames)
308
+ val_loss += loss.item()
309
+
310
+ val_loss /= len(val_loader)
311
+
312
+ print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
313
+
314
+ # Salviamo il modello migliore
315
+ if val_loss < best_val_loss:
316
+ best_val_loss = val_loss
317
+ torch.save(model.state_dict(), os.path.join(MODEL_PATH, "best_model.pth"))
318
+ print(f"Modello salvato con Val Loss: {val_loss:.4f}")
319
+
320
+ # Salviamo il modello finale
321
+ torch.save(model.state_dict(), os.path.join(MODEL_PATH, "Animator2D-v2.pth"))
322
+ print(f"Addestramento completato. Modello finale salvato.")
323
+
324
+ return model
325
+
326
+ # Codice per l'esecuzione dell'addestramento
327
+ if __name__ == "__main__":
328
+ # Dividiamo il dataset in train e validation manualmente
329
+ # dato che abbiamo solo lo split "train"
330
+ train_size = int(0.8 * len(dataset['train'])) # 80% per training
331
+ val_size = len(dataset['train']) - train_size # 20% per validation
332
+
333
+ print(f"Dividendo il dataset: {train_size} esempi per training, {val_size} esempi per validation")
334
+
335
+ # Creiamo i subset
336
+ train_subset, val_subset = random_split(
337
+ dataset['train'],
338
+ [train_size, val_size]
339
+ )
340
+
341
+ # Creiamo i dataset personalizzati
342
+ train_dataset = SpriteDataset(train_subset)
343
+ val_dataset = SpriteDataset(val_subset)
344
+
345
+ print(f"Dataset creati: {len(train_dataset)} esempi di training, {len(val_dataset)} esempi di validation")
346
+
347
+ # Creiamo i dataloader
348
+ train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
349
+ val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)
350
+
351
+ # Creiamo e addestriamo il modello
352
+ model = SpriteGenerator()
353
+ trained_model = train_model(
354
+ model,
355
+ train_loader,
356
+ val_loader,
357
+ epochs=20
358
+ )
359
+
360
+ print("Modello addestrato con successo!")