Lod_34 committed
Commit 98ebcac · unverified · 1 Parent(s): b26d5b2

Add files via upload

Files changed (3)
  1. dataset-visualizer.py +8 -0
  2. gradio-interface.py +42 -0
  3. training-code.py +189 -0
dataset-visualizer.py ADDED
@@ -0,0 +1,8 @@
+ from datasets import load_dataset
+
+ # Load the dataset
+ ds = load_dataset("pawkanarek/spraix_1024")
+
+ # Print the first entries
+ print(ds)
+ print(ds["train"][0])
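
The training script in this commit reads each record's text caption and image fields; a quick schema check along those lines (a sketch under that assumption, not part of the committed files):

    from datasets import load_dataset

    sample = load_dataset("pawkanarek/spraix_1024")["train"][0]
    print(sample["text"])        # caption combining frame count, description, action, direction
    print(sample["image"].size)  # PIL image; the training code converts it to RGB and resizes to 64x64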
gradio-interface.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+ import gradio as gr
+ from transformers import AutoTokenizer
+ from torchvision import transforms
+ from PIL import Image
+
+ # Animator2D is defined in training-code.py; this import assumes that file is
+ # available as an importable module (e.g. renamed to training_code.py)
+ from training_code import Animator2D
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = Animator2D().to(device)
+ # Load the weights saved by the training script
+ model.load_state_dict(torch.load("animator2d_model.pth", map_location=device))
+ model.eval()
+
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+ def generate_sprite(num_frames, description, action, direction):
+     # gr.Number returns a float; the captions use integer frame counts
+     num_frames = int(num_frames)
+     # Build the prompt in the same format used during training
+     text = f"{num_frames}-frame sprite animation of: {description}, that: {action}, facing: {direction}"
+     encoded_text = tokenizer(
+         text, padding="max_length", max_length=128, truncation=True, return_tensors="pt"
+     )
+
+     with torch.no_grad():
+         text_ids = encoded_text['input_ids'].to(device)
+         text_mask = encoded_text['attention_mask'].to(device)
+         generated_sprite = model(text_ids, text_mask).cpu().squeeze(0)
+
+     generated_sprite = (generated_sprite + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
+     generated_sprite = transforms.ToPILImage()(generated_sprite)
+     return generated_sprite
+
+ iface = gr.Interface(
+     fn=generate_sprite,
+     inputs=[
+         gr.Number(label="Number of Frames", value=17),
+         gr.Textbox(label="Sprite Description"),
+         gr.Dropdown(["walks", "runs", "jumps", "attacks"], label="Action"),
+         gr.Dropdown(["North", "South", "East", "West"], label="Direction")
+     ],
+     outputs=gr.Image(type="pil"),
+     title="Animator2D Generator",
+     description="Generates sprite animations from text descriptions."
+ )
+
+ iface.launch(share=False)  # Disable public sharing
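
Since generate_sprite is a plain function, the pipeline can be smoke-tested without launching the UI; a minimal sketch reusing the example prompt from training-code.py (preview.png is an illustrative filename, not part of the committed files):

    # Run one generation end to end and save the result for inspection
    img = generate_sprite(17, "red-haired hobbit in green cape", "shoots with slingshot", "East")
    img.save("preview.png")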
training-code.py ADDED
@@ -0,0 +1,189 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import AutoTokenizer, AutoModel
+ from datasets import load_dataset
+ import numpy as np
+ from PIL import Image
+ import torchvision.transforms as transforms
+
+ class SpriteDataset(Dataset):
+     def __init__(self, dataset_split="train"):
+         # Load the dataset from HuggingFace
+         self.dataset = load_dataset("pawkanarek/spraix_1024", split=dataset_split)
+         self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+         # Define image transforms
+         self.transform = transforms.Compose([
+             transforms.Resize((64, 64)),  # Resize all sprites to the same size
+             transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Per-channel for RGB
+         ])
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         item = self.dataset[idx]
+
+         # Process text description (contains frames, description, action, direction)
+         text = item['text']
+         encoded_text = self.tokenizer(
+             text,
+             padding="max_length",
+             max_length=128,
+             truncation=True,
+             return_tensors="pt"
+         )
+
+         # Process image: item['image'] is already a PIL Image; convert to RGB if needed
+         image = item['image'].convert('RGB')
+         image_tensor = self.transform(image)
+
+         return {
+             'text_ids': encoded_text['input_ids'].squeeze(),
+             'text_mask': encoded_text['attention_mask'].squeeze(),
+             'image': image_tensor
+         }
+
+ class TextEncoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.bert = AutoModel.from_pretrained("bert-base-uncased")
+         self.linear = nn.Linear(768, 512)  # Reduce BERT output dimension
+
+     def forward(self, input_ids, attention_mask):
+         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+         return self.linear(outputs.last_hidden_state[:, 0, :])  # Use [CLS] token
+
+ class SpriteGenerator(nn.Module):
+     def __init__(self, latent_dim=512):
+         super().__init__()
+
+         self.generator = nn.Sequential(
+             # Input: latent_dim x 1 x 1
+             nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
+             nn.BatchNorm2d(512),
+             nn.ReLU(True),
+             # 512 x 4 x 4
+             nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(256),
+             nn.ReLU(True),
+             # 256 x 8 x 8
+             nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(128),
+             nn.ReLU(True),
+             # 128 x 16 x 16
+             nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(64),
+             nn.ReLU(True),
+             # 64 x 32 x 32
+             nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
+             nn.Tanh()
+             # Output: 3 x 64 x 64
+         )
+
+     def forward(self, z):
+         z = z.unsqueeze(-1).unsqueeze(-1)  # Add spatial dimensions
+         return self.generator(z)
+
+ class Animator2D(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.text_encoder = TextEncoder()
+         self.sprite_generator = SpriteGenerator()
+
+     def forward(self, input_ids, attention_mask):
+         text_features = self.text_encoder(input_ids, attention_mask)
+         generated_sprite = self.sprite_generator(text_features)
+         return generated_sprite
+
+ def train_model(num_epochs=100, batch_size=32, learning_rate=0.0002):
+     # Initialize dataset and dataloader
+     dataset = SpriteDataset()
+     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+     # Initialize model and optimizer
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = Animator2D().to(device)
+     optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.5, 0.999))
+     criterion = nn.MSELoss()
+
+     # Training loop
+     for epoch in range(num_epochs):
+         for batch_idx, batch in enumerate(dataloader):
+             # Move data to device
+             text_ids = batch['text_ids'].to(device)
+             text_mask = batch['text_mask'].to(device)
+             real_images = batch['image'].to(device)
+
+             # Forward pass
+             generated_images = model(text_ids, text_mask)
+
+             # Calculate loss
+             loss = criterion(generated_images, real_images)
+
+             # Backward pass and optimization
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+
+             if batch_idx % 100 == 0:
+                 print(f"Epoch [{epoch}/{num_epochs}] Batch [{batch_idx}/{len(dataloader)}] Loss: {loss.item():.4f}")
+
+     # Save the trained model
+     torch.save(model.state_dict(), "animator2d_model.pth")
+     return model
+
+ def generate_sprite_animation(model, num_frames, description, action, direction):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = model.to(device)  # Ensure model and inputs share a device
+     model.eval()
+
+     # Prepare input text in the same format as the dataset captions
+     text = f"{num_frames}-frame sprite animation of: {description}, that: {action}, facing: {direction}"
+     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+     encoded_text = tokenizer(
+         text,
+         padding="max_length",
+         max_length=128,
+         truncation=True,
+         return_tensors="pt"
+     )
+
+     # Generate sprite sheet
+     with torch.no_grad():
+         text_ids = encoded_text['input_ids'].to(device)
+         text_mask = encoded_text['attention_mask'].to(device)
+         generated_sprite = model(text_ids, text_mask)
+
+     # Convert to image
+     generated_sprite = generated_sprite.cpu().squeeze(0)
+     generated_sprite = (generated_sprite + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
+     generated_sprite = transforms.ToPILImage()(generated_sprite)
+
+     return generated_sprite
+
+ # Example usage
+ if __name__ == "__main__":
+     # Train the model
+     model = train_model()
+
+     # Generate a new sprite animation
+     test_params = {
+         "num_frames": 17,
+         "description": "red-haired hobbit in green cape",
+         "action": "shoots with slingshot",
+         "direction": "East"
+     }
+
+     sprite_sheet = generate_sprite_animation(
+         model,
+         test_params["num_frames"],
+         test_params["description"],
+         test_params["action"],
+         test_params["direction"]
+     )
+     sprite_sheet.save("generated_sprite.png")
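
Once training has produced animator2d_model.pth, inference can reload the saved weights instead of calling train_model() again; a minimal sketch, assuming the classes and functions above are in scope:

    # Rebuild the architecture and load the trained weights
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Animator2D().to(device)
    model.load_state_dict(torch.load("animator2d_model.pth", map_location=device))

    sprite = generate_sprite_animation(model, 17, "red-haired hobbit in green cape", "shoots with slingshot", "East")
    sprite.save("generated_sprite.png")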