Lod34 commited on
Commit
57ee356
·
verified ·
1 Parent(s): be2a526

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -64
app.py CHANGED
@@ -7,7 +7,101 @@ from huggingface_hub import hf_hub_download
7
  import torch.nn as nn
8
 
9
  class SpriteGenerator(nn.Module):
10
- # ... (la classe SpriteGenerator rimane invariata) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def initialize_model():
13
  print("Inizializzazione del modello...")
@@ -28,7 +122,7 @@ def initialize_model():
28
  model.load_state_dict(state_dict)
29
  model = model.to(device)
30
  model.eval()
31
- print("Modello caricato con successo da Hugging Face Hub!")
32
  return model, device
33
  except Exception as e:
34
  print(f"Errore nel caricamento del modello: {str(e)}")
@@ -65,80 +159,58 @@ def generate_sprite(prompt, num_frames=8):
65
  raise
66
 
67
  # Inizializzazione globale
68
- print("Caricamento del modello...")
69
  try:
 
70
  model, device = initialize_model()
71
  tokenizer = AutoTokenizer.from_pretrained("t5-base")
72
 
73
- # Creazione dell'interfaccia Gradio
74
  interface = gr.Interface(
75
  fn=generate_sprite,
76
  inputs=[
77
- gr.Textbox(label="Descrivi lo sprite che vuoi generare"),
78
- gr.Slider(minimum=1, maximum=16, value=8, step=1, label="Numero di frame")
 
 
 
 
 
 
 
 
 
 
79
  ],
80
  outputs=gr.Image(label="Sprite generato"),
81
- title="Animator2D-v2 Sprite Generator",
82
- description="Genera sprite animati da descrizioni testuali"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  )
84
 
85
  # Avvio dell'interfaccia
86
  interface.launch()
 
87
  except Exception as e:
88
  print(f"Errore nell'inizializzazione dell'applicazione: {str(e)}")
89
- raise e
90
-
91
- # Interfaccia Gradio
92
- def create_interface():
93
- with gr.Blocks(title="Animated Sprite Generator") as demo:
94
- gr.Markdown("# 🎮 AI Animated Sprite Generator")
95
- gr.Markdown("""
96
- Generate animated sprites using AI! Just describe your character and choose the animation settings.
97
- """)
98
-
99
- with gr.Row():
100
- with gr.Column():
101
- char_desc = gr.Textbox(
102
- label="Character Description",
103
- placeholder="Ex: a knight with golden armor and a fire sword",
104
- lines=3
105
- )
106
- num_frames = gr.Slider(
107
- minimum=1,
108
- maximum=8,
109
- step=1,
110
- value=4,
111
- label="Number of Animation Frames"
112
- )
113
- char_action = gr.Dropdown(
114
- choices=["idle", "walk", "run", "attack", "jump", "die", "cast spell", "dance"],
115
- label="Character Action",
116
- value="idle"
117
- )
118
- view_direction = gr.Dropdown(
119
- choices=["front", "back", "left", "right", "front-left", "front-right", "back-left", "back-right"],
120
- label="Viewing Direction",
121
- value="front"
122
- )
123
- generate_btn = gr.Button("Generate Animated Sprite")
124
-
125
- with gr.Column():
126
- animated_output = gr.Image(label="Animated Sprite (GIF)")
127
-
128
- generate_btn.click(
129
- fn=generate_animated_sprite,
130
- inputs=[char_desc, num_frames, char_action, view_direction],
131
- outputs=animated_output
132
- )
133
-
134
- gr.Examples([
135
- ["A wizard with blue cloak and pointed hat", 4, "cast spell", "front"],
136
- ["A warrior with heavy armor and axe", 6, "attack", "right"],
137
- ["A ninja with black clothes and throwing stars", 8, "run", "front-left"],
138
- ["A princess with golden crown and pink dress", 4, "dance", "front"]
139
- ], inputs=[char_desc, num_frames, char_action, view_direction])
140
-
141
- return demo
142
-
143
- # Crea l'interfaccia
144
- demo = create_interface()
 
7
  import torch.nn as nn
8
 
9
  class SpriteGenerator(nn.Module):
10
+ def __init__(self, text_encoder_name="t5-base", latent_dim=512):
11
+ super(SpriteGenerator, self).__init__()
12
+
13
+ # Text encoder (T5 with lm_head)
14
+ self.text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder_name)
15
+ for param in self.text_encoder.parameters():
16
+ param.requires_grad = False
17
+
18
+ # Proiezione dal testo al latent space
19
+ self.text_projection = nn.Sequential(
20
+ nn.Linear(768, latent_dim),
21
+ nn.LeakyReLU(0.2),
22
+ nn.Linear(latent_dim, latent_dim)
23
+ )
24
+
25
+ # Generator
26
+ self.generator = nn.Sequential(
27
+ # Input: latent_dim x 1 x 1 -> 512 x 4 x 4
28
+ nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
29
+ nn.BatchNorm2d(512),
30
+ nn.ReLU(True),
31
+
32
+ # 512 x 4 x 4 -> 256 x 8 x 8
33
+ nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
34
+ nn.BatchNorm2d(256),
35
+ nn.ReLU(True),
36
+
37
+ # 256 x 8 x 8 -> 128 x 16 x 16
38
+ nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
39
+ nn.BatchNorm2d(128),
40
+ nn.ReLU(True),
41
+
42
+ # 128 x 16 x 16 -> 64 x 32 x 32
43
+ nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
44
+ nn.BatchNorm2d(64),
45
+ nn.ReLU(True),
46
+
47
+ # 64 x 32 x 32 -> 32 x 64 x 64
48
+ nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
49
+ nn.BatchNorm2d(32),
50
+ nn.ReLU(True),
51
+
52
+ # 32 x 64 x 64 -> 16 x 128 x 128
53
+ nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False),
54
+ nn.BatchNorm2d(16),
55
+ nn.ReLU(True),
56
+
57
+ # 16 x 128 x 128 -> 3 x 256 x 256
58
+ nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False),
59
+ )
60
+
61
+ # Frame interpolator
62
+ self.frame_interpolator = nn.Sequential(
63
+ nn.Linear(latent_dim + 1, latent_dim),
64
+ nn.LeakyReLU(0.2),
65
+ nn.Linear(latent_dim, latent_dim),
66
+ nn.LeakyReLU(0.2)
67
+ )
68
+
69
+ def forward(self, input_ids, attention_mask, num_frames=1):
70
+ batch_size = input_ids.shape[0]
71
+
72
+ # Encode text usando il T5 completo
73
+ text_outputs = self.text_encoder.encoder(
74
+ input_ids=input_ids,
75
+ attention_mask=attention_mask,
76
+ return_dict=True
77
+ )
78
+
79
+ # Get text features
80
+ text_features = text_outputs.last_hidden_state.mean(dim=1)
81
+
82
+ # Project to latent space
83
+ latent_vector = self.text_projection(text_features)
84
+
85
+ # Generate multiple frames if needed
86
+ all_frames = []
87
+ for frame_idx in range(max(num_frames.max().item(), 1)):
88
+ frame_info = torch.ones((batch_size, 1), device=latent_vector.device) * frame_idx / max(num_frames.max().item(), 1)
89
+
90
+ # Combine latent vector with frame info
91
+ frame_latent = self.frame_interpolator(
92
+ torch.cat([latent_vector, frame_info], dim=1)
93
+ )
94
+
95
+ # Generate frame
96
+ frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3)
97
+ frame = self.generator(frame_latent_reshaped)
98
+ frame = torch.tanh(frame)
99
+ all_frames.append(frame)
100
+
101
+ # Stack all frames
102
+ sprites = torch.stack(all_frames, dim=1)
103
+
104
+ return sprites
105
 
106
  def initialize_model():
107
  print("Inizializzazione del modello...")
 
122
  model.load_state_dict(state_dict)
123
  model = model.to(device)
124
  model.eval()
125
+ print(f"Modello caricato con successo su {device}!")
126
  return model, device
127
  except Exception as e:
128
  print(f"Errore nel caricamento del modello: {str(e)}")
 
159
  raise
160
 
161
  # Inizializzazione globale
162
+ print("Caricamento del modello e configurazione dell'interfaccia...")
163
  try:
164
+ # Inizializzazione del modello e del tokenizer
165
  model, device = initialize_model()
166
  tokenizer = AutoTokenizer.from_pretrained("t5-base")
167
 
168
+ # Configurazione dell'interfaccia Gradio
169
  interface = gr.Interface(
170
  fn=generate_sprite,
171
  inputs=[
172
+ gr.Textbox(
173
+ label="Descrivi lo sprite che vuoi generare",
174
+ placeholder="Esempio: un personaggio pixel art che cammina"
175
+ ),
176
+ gr.Slider(
177
+ minimum=1,
178
+ maximum=16,
179
+ value=8,
180
+ step=1,
181
+ label="Numero di frame",
182
+ info="Più frame = animazione più fluida ma generazione più lenta"
183
+ )
184
  ],
185
  outputs=gr.Image(label="Sprite generato"),
186
+ title="🎮 Animator2D-v2 Sprite Generator",
187
+ description="""
188
+ ## Generatore di Sprite Animati
189
+ Questo strumento genera sprite pixel art da descrizioni testuali.
190
+
191
+ ### Come usare:
192
+ 1. Inserisci una descrizione dello sprite che vuoi generare
193
+ 2. Regola il numero di frame desiderati
194
+ 3. Clicca su Submit e attendi la generazione
195
+
196
+ ### Tips:
197
+ - Sii specifico nella descrizione
198
+ - Prova diversi numeri di frame per risultati diversi
199
+ - Le descrizioni in inglese potrebbero funzionare meglio
200
+ """,
201
+ article="""
202
+ ### Note:
203
+ - La generazione può richiedere alcuni secondi
204
+ - Vengono mostrati solo i primi frame dell'animazione
205
+ - Per risultati migliori, usa descrizioni dettagliate
206
+
207
+ Creato da [Lod34](https://huggingface.co/Lod34)
208
+ """
209
  )
210
 
211
  # Avvio dell'interfaccia
212
  interface.launch()
213
+
214
  except Exception as e:
215
  print(f"Errore nell'inizializzazione dell'applicazione: {str(e)}")
216
+ raise