Lod34 commited on
Commit
be2a526
·
verified ·
1 Parent(s): 95b77dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -97
app.py CHANGED
@@ -7,101 +7,7 @@ from huggingface_hub import hf_hub_download
7
  import torch.nn as nn
8
 
9
  class SpriteGenerator(nn.Module):
10
- def __init__(self, text_encoder_name="t5-base", latent_dim=512):
11
- super(SpriteGenerator, self).__init__()
12
-
13
- # Text encoder (T5 with lm_head)
14
- self.text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder_name)
15
- for param in self.text_encoder.parameters():
16
- param.requires_grad = False
17
-
18
- # Proiezione dal testo al latent space
19
- self.text_projection = nn.Sequential(
20
- nn.Linear(768, latent_dim),
21
- nn.LeakyReLU(0.2),
22
- nn.Linear(latent_dim, latent_dim)
23
- )
24
-
25
- # Generator
26
- self.generator = nn.Sequential(
27
- # Input: latent_dim x 1 x 1 -> 512 x 4 x 4
28
- nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
29
- nn.BatchNorm2d(512),
30
- nn.ReLU(True),
31
-
32
- # 512 x 4 x 4 -> 256 x 8 x 8
33
- nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
34
- nn.BatchNorm2d(256),
35
- nn.ReLU(True),
36
-
37
- # 256 x 8 x 8 -> 128 x 16 x 16
38
- nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
39
- nn.BatchNorm2d(128),
40
- nn.ReLU(True),
41
-
42
- # 128 x 16 x 16 -> 64 x 32 x 32
43
- nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
44
- nn.BatchNorm2d(64),
45
- nn.ReLU(True),
46
-
47
- # 64 x 32 x 32 -> 32 x 64 x 64
48
- nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
49
- nn.BatchNorm2d(32),
50
- nn.ReLU(True),
51
-
52
- # 32 x 64 x 64 -> 16 x 128 x 128
53
- nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False),
54
- nn.BatchNorm2d(16),
55
- nn.ReLU(True),
56
-
57
- # 16 x 128 x 128 -> 3 x 256 x 256
58
- nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False),
59
- )
60
-
61
- # Frame interpolator
62
- self.frame_interpolator = nn.Sequential(
63
- nn.Linear(latent_dim + 1, latent_dim),
64
- nn.LeakyReLU(0.2),
65
- nn.Linear(latent_dim, latent_dim),
66
- nn.LeakyReLU(0.2)
67
- )
68
-
69
- def forward(self, input_ids, attention_mask, num_frames=1):
70
- batch_size = input_ids.shape[0]
71
-
72
- # Encode text usando il T5 completo
73
- text_outputs = self.text_encoder.encoder(
74
- input_ids=input_ids,
75
- attention_mask=attention_mask,
76
- return_dict=True
77
- )
78
-
79
- # Get text features
80
- text_features = text_outputs.last_hidden_state.mean(dim=1)
81
-
82
- # Project to latent space
83
- latent_vector = self.text_projection(text_features)
84
-
85
- # Generate multiple frames if needed
86
- all_frames = []
87
- for frame_idx in range(max(num_frames.max().item(), 1)):
88
- frame_info = torch.ones((batch_size, 1), device=latent_vector.device) * frame_idx / max(num_frames.max().item(), 1)
89
-
90
- # Combine latent vector with frame info
91
- frame_latent = self.frame_interpolator(
92
- torch.cat([latent_vector, frame_info], dim=1)
93
- )
94
-
95
- # Generate frame
96
- frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3)
97
- frame = self.generator(frame_latent_reshaped)
98
- frame = torch.tanh(frame)
99
- all_frames.append(frame)
100
-
101
- # Stack all frames
102
- sprites = torch.stack(all_frames, dim=1)
103
-
104
- return sprites
105
 
106
  def initialize_model():
107
  print("Inizializzazione del modello...")
@@ -110,12 +16,19 @@ def initialize_model():
110
  model = SpriteGenerator()
111
 
112
  try:
 
 
 
 
 
 
 
113
  # Carica il modello
114
- state_dict = torch.load("Animator2D-v2.pth", map_location=device)
115
  model.load_state_dict(state_dict)
116
  model = model.to(device)
117
  model.eval()
118
- print("Modello caricato con successo!")
119
  return model, device
120
  except Exception as e:
121
  print(f"Errore nel caricamento del modello: {str(e)}")
 
7
  import torch.nn as nn
8
 
9
  class SpriteGenerator(nn.Module):
10
+ # ... (la classe SpriteGenerator rimane invariata) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def initialize_model():
13
  print("Inizializzazione del modello...")
 
16
  model = SpriteGenerator()
17
 
18
  try:
19
+ # Scarica il modello da Hugging Face Hub
20
+ model_path = hf_hub_download(
21
+ repo_id="Lod34/Animator2D-v2",
22
+ filename="pytorch_model.bin",
23
+ repo_type="model"
24
+ )
25
+
26
  # Carica il modello
27
+ state_dict = torch.load(model_path, map_location=device)
28
  model.load_state_dict(state_dict)
29
  model = model.to(device)
30
  model.eval()
31
+ print("Modello caricato con successo da Hugging Face Hub!")
32
  return model, device
33
  except Exception as e:
34
  print(f"Errore nel caricamento del modello: {str(e)}")