lovelyai999 commited on
Commit
54130af
·
verified ·
1 Parent(s): 29f7349

Upload 3 files

Browse files
Files changed (3) hide show
  1. best_model.pth +2 -2
  2. imageAI.py +248 -51
  3. optimizer.pth +2 -2
best_model.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dd85c6f84bca45ea689fb0a0e402b5432a86bf2020046b52cf3b1e5aa19bf041
3
- size 17424939
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:134657576082cf3606a40722ead3d5e5df9111178f74f17584263de4082c300c
3
+ size 72353922
imageAI.py CHANGED
@@ -1,22 +1,22 @@
1
  try:
2
- import google.colab
3
- IN_COLAB = True
4
- from google.colab import drive,files
5
- from google.colab import output
6
- drive.mount('/gdrive')
7
- Gbase="/gdrive/MyDrive/generate/"
8
- cache_dir="/gdrive/MyDrive/hf/"
9
- import sys
10
- sys.path.append(Gbase)
11
  except:
12
- IN_COLAB = False
13
- Gbase="./"
14
- cache_dir="./hf/"
15
 
16
-
17
- import cv2,os
18
  import numpy as np
19
- import random,string
20
  import torch
21
  import torch.nn as nn
22
  import torch.nn.functional as F
@@ -28,36 +28,129 @@ print(f"Using device: {device}")
28
 
29
  IMAGE_SIZE = 64
30
  NUM_SAMPLES = 1000
31
- BATCH_SIZE = 4
32
  EPOCHS = 500
33
  LEARNING_RATE = 0.001
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  class SimpleModel(nn.Module):
36
  def __init__(self, path=None):
37
  super(SimpleModel, self).__init__()
38
- self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
39
- self.bn1 = nn.BatchNorm2d(32)
40
- self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
41
- self.bn2 = nn.BatchNorm2d(64)
42
- self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
43
- self.bn3 = nn.BatchNorm2d(128)
44
- self.pool = nn.MaxPool2d(2, 2)
45
- self.fc1 = nn.Linear(128 * 8 * 8, 512)
46
- self.fc2 = nn.Linear(512, 128)
47
- self.fc3 = nn.Linear(128, 1)
 
 
 
 
 
 
 
 
48
  self.dropout = nn.Dropout(0.5)
49
 
50
  if path and os.path.exists(path):
51
  self.load_state_dict(torch.load(path, map_location=device))
52
 
53
  def forward(self, x):
54
- x = self.pool(F.leaky_relu(self.bn1(self.conv1(x))))
55
- x = self.pool(F.leaky_relu(self.bn2(self.conv2(x))))
56
- x = self.pool(F.leaky_relu(self.bn3(self.conv3(x))))
57
- x = x.view(-1, 128 * 8 * 8)
58
- x = F.leaky_relu(self.fc1(x))
 
 
 
 
 
59
  x = self.dropout(x)
60
- x = F.leaky_relu(self.fc2(x))
61
  x = self.dropout(x)
62
  x = self.fc3(x)
63
  return x
@@ -66,43 +159,147 @@ class SimpleModel(nn.Module):
66
  self.eval()
67
  with torch.no_grad():
68
  if isinstance(image, str) and os.path.isfile(image):
69
- # 如果輸入是圖片文件路徑
70
- img = cv2.imread(image, cv2.IMREAD_GRAYSCALE)
71
  img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
72
  elif isinstance(image, np.ndarray):
73
- # 如果輸入是 numpy 數組
74
- if image.ndim == 3:
75
- img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
76
- else:
77
- img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
78
  img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
79
  else:
80
  raise ValueError("Input should be an image file path or a numpy array")
81
 
82
- img_tensor = torch.FloatTensor(img).unsqueeze(0).unsqueeze(0) / 255.0
83
  img_tensor = img_tensor.to(device)
84
  output = self(img_tensor).item()
85
 
86
- # 將輸出四捨五入到最接近的整數
87
  num_instructions = round(output)
88
 
89
- # 生成相應數量的繪圖指令
90
  instructions = []
91
  for _ in range(num_instructions):
92
- shape = random.choice(['line', 'rectangle', 'circle', 'ellipse', 'polygon'])
93
- if shape == 'line':
94
- instructions.append(f"cv2.line(image, {(random.randint(0, IMAGE_SIZE), random.randint(0, IMAGE_SIZE))}, {(random.randint(0, IMAGE_SIZE), random.randint(0, IMAGE_SIZE))}, {random.randint(0, 255)}, {random.randint(1, 3)})")
95
- elif shape == 'rectangle':
96
- instructions.append(f"cv2.rectangle(image, {(random.randint(0, IMAGE_SIZE-10), random.randint(0, IMAGE_SIZE-10))}, {(random.randint(10, IMAGE_SIZE), random.randint(10, IMAGE_SIZE))}, {random.randint(0, 255)}, {random.randint(1, 3)})")
97
  elif shape == 'circle':
98
- instructions.append(f"cv2.circle(image, {(random.randint(10, IMAGE_SIZE-10), random.randint(10, IMAGE_SIZE-10))}, {random.randint(5, 30)}, {random.randint(0, 255)}, {random.randint(1, 3)})")
99
  elif shape == 'ellipse':
100
- instructions.append(f"cv2.ellipse(image, {(random.randint(10, IMAGE_SIZE-10), random.randint(10, IMAGE_SIZE-10))}, {(random.randint(5, 30), random.randint(5, 30))}, {random.randint(0, 360)}, 0, 360, {random.randint(0, 255)}, {random.randint(1, 3)})")
101
  elif shape == 'polygon':
102
  num_points = random.randint(3, 6)
103
  points = [(random.randint(0, IMAGE_SIZE), random.randint(0, IMAGE_SIZE)) for _ in range(num_points)]
104
- instructions.append(f"cv2.polylines(image, [np.array({points})], True, {random.randint(0, 255)}, {random.randint(1, 3)})")
105
-
106
 
107
  return instructions
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Environment bootstrap: detect Google Colab and pick storage paths accordingly.
try:
    import google.colab
    IN_COLAB = True
    from google.colab import drive, files
    from google.colab import output
    drive.mount('/gdrive')
    Gbase = "/gdrive/MyDrive/generate/"
    cache_dir = "/gdrive/MyDrive/hf/"
    import sys
    sys.path.append(Gbase)
# `except Exception` instead of a bare `except:` so Ctrl-C / SystemExit still
# propagate; any import or mount failure falls back to local directories.
except Exception:
    IN_COLAB = False
    Gbase = "./"
    cache_dir = "./hf/"
15
 
16
+ import cv2
17
+ import os
18
  import numpy as np
19
+ import random
20
  import torch
21
  import torch.nn as nn
22
  import torch.nn.functional as F
 
28
 
29
  IMAGE_SIZE = 64
30
  NUM_SAMPLES = 1000
31
+ BATCH_SIZE = 200
32
  EPOCHS = 500
33
  LEARNING_RATE = 0.001
34
 
35
def generate_sample(num_shapes=1):
    """Render `num_shapes` random filled shapes onto a blank RGB canvas.

    Returns a dict with:
        'image':        (IMAGE_SIZE, IMAGE_SIZE, 3) uint8 array with the shapes drawn.
        'instructions': list of cv2 call strings that reproduce each shape.

    NOTE: the order of `random.*` calls is part of the contract — reseeding
    `random` yields a reproducible dataset.
    """
    canvas = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)
    recorded = []

    for _ in range(num_shapes):
        kind = random.choice(['rectangle', 'circle', 'ellipse', 'polygon'])
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

        if kind == 'rectangle':
            # Top-left leaves >= 10px of room; the extent never exceeds the canvas.
            top_left = (random.randint(0, IMAGE_SIZE - 10), random.randint(0, IMAGE_SIZE - 10))
            bottom_right = (top_left[0] + random.randint(10, IMAGE_SIZE - top_left[0]),
                            top_left[1] + random.randint(10, IMAGE_SIZE - top_left[1]))
            cv2.rectangle(canvas, top_left, bottom_right, color, -1)
            recorded.append(f"cv2.rectangle(image, {top_left}, {bottom_right}, {color}, -1)")
        elif kind == 'circle':
            # Radius is capped by the distance to the nearest border, so the
            # circle stays fully inside the image.
            center = (random.randint(10, IMAGE_SIZE - 10), random.randint(10, IMAGE_SIZE - 10))
            radius = random.randint(5, min(center[0], center[1], IMAGE_SIZE - center[0], IMAGE_SIZE - center[1]))
            cv2.circle(canvas, center, radius, color, -1)
            recorded.append(f"cv2.circle(image, {center}, {radius}, {color}, -1)")
        elif kind == 'ellipse':
            # Axes may extend past the border; cv2 clips the drawing to the canvas.
            center = (random.randint(10, IMAGE_SIZE - 10), random.randint(10, IMAGE_SIZE - 10))
            axes = (random.randint(5, 30), random.randint(5, 30))
            angle = random.randint(0, 360)
            cv2.ellipse(canvas, center, axes, angle, 0, 360, color, -1)
            recorded.append(f"cv2.ellipse(image, {center}, {axes}, {angle}, 0, 360, {color}, -1)")
        elif kind == 'polygon':
            num_points = random.randint(3, 6)
            vertices = np.array([(random.randint(0, IMAGE_SIZE), random.randint(0, IMAGE_SIZE)) for _ in range(num_points)], np.int32)
            vertices = vertices.reshape((-1, 1, 2))
            cv2.fillPoly(canvas, [vertices], color)
            recorded.append(f"cv2.fillPoly(image, [{vertices.tolist()}], {color})")

    return {'image': canvas, 'instructions': recorded}
71
+
72
def generate_dataset(NUM_SAMPLES=NUM_SAMPLES, maxNumShape=3):
    """Build a list of NUM_SAMPLES samples, each with 1..maxNumShape random shapes.

    The per-sample shape count is drawn before each `generate_sample` call,
    preserving the RNG call order of the original implementation.
    """
    return [
        generate_sample(num_shapes=random.randint(1, maxNumShape))
        for _ in range(NUM_SAMPLES)
    ]
79
+
80
class ImageDataset(Dataset):
    """Adapts the generated sample list to a torch Dataset.

    Each item is (image_tensor, shape_count): the image converted from
    HWC uint8 to CHW float in [0, 1], paired with the number of drawing
    instructions (the regression target).
    """

    def __init__(self, dataset):
        # `dataset` is a list of dicts with keys 'image' and 'instructions'.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        # HWC uint8 -> CHW float32, scaled to [0, 1].
        tensor = torch.FloatTensor(sample['image']).permute(2, 0, 1) / 255.0
        return tensor, len(sample['instructions'])
91
+
92
class SelfAttention(nn.Module):
    """Non-local self-attention over spatial positions (SAGAN-style).

    Projects the feature map to query/key/value with 1x1 convs, computes a
    (N x N) attention map over the N = H*W positions, and adds the attended
    features back to the input scaled by a learned `gamma` (initialized to 0,
    so the module starts as an identity mapping).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        # Zero init: the residual branch contributes nothing until trained.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        b, c, w, h = x.shape

        # (B, N, C//8) x (B, C//8, N) -> (B, N, N) attention over positions.
        q = self.query(x).flatten(2).transpose(1, 2)
        k = self.key(x).flatten(2)
        attn = F.softmax(torch.bmm(q, k), dim=-1)

        # Attend the full-channel values and restore the spatial layout.
        v = self.value(x).flatten(2)
        attended = torch.bmm(v, attn.transpose(1, 2)).reshape(b, c, w, h)

        return self.gamma * attended + x
114
+
115
  class SimpleModel(nn.Module):
116
    def __init__(self, path=None):
        """Build the shape-counting CNN; optionally load weights from `path`.

        Args:
            path: optional checkpoint file. Weights are loaded only when the
                file exists, so a missing checkpoint starts from scratch.

        NOTE: module registration order below determines the order of
        `parameters()`, which the saved Adam optimizer state is paired
        against — do not reorder these assignments.
        """
        super(SimpleModel, self).__init__()

        # Feature extractor: 3-channel RGB input, widening to 256 channels.
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)

        # Adaptive pooling fixes the spatial size at 8x8 regardless of input size.
        self.pool = nn.AdaptiveAvgPool2d((8, 8))

        # Spatial self-attention over the 256-channel feature map.
        self.attention = SelfAttention(256)

        # Regression head: 256*8*8 features -> single scalar (shape count).
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 1)

        self.dropout = nn.Dropout(0.5)

        if path and os.path.exists(path):
            # map_location keeps CPU-only machines able to load GPU checkpoints.
            self.load_state_dict(torch.load(path, map_location=device))
140
 
141
  def forward(self, x):
142
+ x = self.pool(F.mish(self.bn1(self.conv1(x))))
143
+ x = self.pool(F.mish(self.bn2(self.conv2(x))))
144
+ x = F.mish(self.bn3(self.conv3(x)))
145
+ x = F.mish(self.bn4(self.conv4(x)))
146
+
147
+ x = self.attention(x)
148
+
149
+ x = self.pool(x)
150
+ x = x.view(-1, 256 * 8 * 8)
151
+ x = F.mish(self.fc1(x))
152
  x = self.dropout(x)
153
+ x = F.mish(self.fc2(x))
154
  x = self.dropout(x)
155
  x = self.fc3(x)
156
  return x
 
159
  self.eval()
160
  with torch.no_grad():
161
  if isinstance(image, str) and os.path.isfile(image):
162
+ img = cv2.imread(image)
 
163
  img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
164
  elif isinstance(image, np.ndarray):
165
+ if image.ndim == 2:
166
+ img = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
 
 
 
167
  img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
168
  else:
169
  raise ValueError("Input should be an image file path or a numpy array")
170
 
171
+ img_tensor = torch.FloatTensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0
172
  img_tensor = img_tensor.to(device)
173
  output = self(img_tensor).item()
174
 
 
175
  num_instructions = round(output)
176
 
 
177
  instructions = []
178
  for _ in range(num_instructions):
179
+ shape = random.choice(['rectangle', 'circle', 'ellipse', 'polygon'])
180
+ color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
181
+ if shape == 'rectangle':
182
+ instructions.append(f"cv2.rectangle(image, {(random.randint(0, IMAGE_SIZE-10), random.randint(0, IMAGE_SIZE-10))}, {(random.randint(10, IMAGE_SIZE), random.randint(10, IMAGE_SIZE))}, {color}, -1)")
 
183
  elif shape == 'circle':
184
+ instructions.append(f"cv2.circle(image, {(random.randint(10, IMAGE_SIZE-10), random.randint(10, IMAGE_SIZE-10))}, {random.randint(5, 30)}, {color}, -1)")
185
  elif shape == 'ellipse':
186
+ instructions.append(f"cv2.ellipse(image, {(random.randint(10, IMAGE_SIZE-10), random.randint(10, IMAGE_SIZE-10))}, {(random.randint(5, 30), random.randint(5, 30))}, {random.randint(0, 360)}, 0, 360, {color}, -1)")
187
  elif shape == 'polygon':
188
  num_points = random.randint(3, 6)
189
  points = [(random.randint(0, IMAGE_SIZE), random.randint(0, IMAGE_SIZE)) for _ in range(num_points)]
190
+ instructions.append(f"cv2.fillPoly(image, [np.array({points})], {color})")
 
191
 
192
  return instructions
193
 
194
def train(model, train_loader, optimizer, criterion):
    """Run one training epoch.

    Args:
        model: network producing a (batch, 1) regression output.
        train_loader: yields (image_batch, shape_count) pairs.
        optimizer: optimizer updating `model`'s parameters.
        criterion: regression loss (MSELoss in this file).

    Returns:
        (mean batch loss, accuracy) where accuracy counts predictions that
        round to the exact target count.
    """
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.float().to(device)
        optimizer.zero_grad()
        # BUGFIX: squeeze(-1), not squeeze() — a batch of size 1 would collapse
        # (1, 1) to a 0-dim scalar and broadcast-mismatch against the (1,) target.
        output = model(data).squeeze(-1)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # A prediction is "correct" when it rounds to the exact shape count.
        predicted = torch.round(output)
        correct_predictions += (predicted == target).sum().item()
        total_predictions += target.size(0)

        if batch_idx % 3000 == 0:
            print(f'Train Batch {batch_idx}/{len(train_loader)} Loss: {loss.item():.6f}')

    accuracy = correct_predictions / total_predictions
    return total_loss / len(train_loader), accuracy
218
+
219
def test(model, test_loader, criterion, print_predictions=False):
    """Evaluate `model` on `test_loader` without gradient tracking.

    Args:
        model: network producing a (batch, 1) regression output.
        test_loader: yields (image_batch, shape_count) pairs.
        criterion: regression loss (MSELoss in this file).
        print_predictions: when True, print the first 10 prediction/target pairs.

    Returns:
        (average loss, accuracy, all predictions, all targets) where accuracy
        counts predictions that round to the exact target count.
    """
    model.eval()
    test_loss = 0
    correct_predictions = 0
    total_predictions = 0
    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.float().to(device)
            # BUGFIX: squeeze(-1), not squeeze() — keeps a 1-element batch as
            # shape (1,) instead of a 0-dim scalar (see train()).
            output = model(data).squeeze(-1)
            test_loss += criterion(output, target).item()

            predicted = torch.round(output)
            correct_predictions += (predicted == target).sum().item()
            total_predictions += target.size(0)

            all_predictions.extend(output.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    test_loss /= len(test_loader)
    accuracy = correct_predictions / total_predictions
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}')

    if print_predictions:
        print("Sample predictions:")
        for pred, targ in zip(all_predictions[:10], all_targets[:10]):
            print(f"Prediction: {pred:.2f}, Target: {targ:.2f}")

    return test_loss, accuracy, all_predictions, all_targets
249
+
250
def train1(NUM_SAMPLES=NUM_SAMPLES, maxNumShape=1, EPOCHS=EPOCHS):
    """One training run: resume from checkpoint, generate data, train, save best.

    Args:
        NUM_SAMPLES: number of synthetic samples to generate (80/20 train/test split).
        maxNumShape: maximum shapes per image (the regression target range).
        EPOCHS: number of epochs for this run.

    Side effects: writes 'best_model.pth' and 'optimizer.pth' under Gbase
    whenever the test metric improves.
    """
    # Resume model weights if a checkpoint exists (handled inside SimpleModel).
    model = SimpleModel(path=os.path.join(Gbase, 'best_model.pth')).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Resume optimizer state (Adam moments) so training continues smoothly.
    optimizer_path = os.path.join(Gbase, 'optimizer.pth')
    if os.path.exists(optimizer_path):
        print("Loading optimizer state...")
        optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))

    criterion = nn.MSELoss()

    # Fixed seed before data generation: every call with the same arguments
    # regenerates the identical dataset and split (reproducible, but repeated
    # calls do NOT see fresh data — presumably intentional; confirm).
    seed = 618 * 382 * 33
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    dataset = generate_dataset(NUM_SAMPLES=NUM_SAMPLES, maxNumShape=maxNumShape)
    # 80/20 train/test split by position (data is i.i.d. random, so no shuffle needed).
    train_size = int(0.8 * len(dataset))
    train_dataset = ImageDataset(dataset[:train_size])
    test_dataset = ImageDataset(dataset[train_size:])

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

    best_loss = float('inf')
    best_accuracy = 0

    for epoch in range(EPOCHS):
        print(f'Epoch {epoch+1}/{EPOCHS}')
        train_loss, train_accuracy = train(model, train_loader, optimizer, criterion)
        test_loss, test_accuracy, predictions, targets = test(model, test_loader, criterion, print_predictions=True)

        print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}')
        print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')

        # Checkpoint on better accuracy, tie-broken by lower loss.
        if test_accuracy > best_accuracy or (test_accuracy == best_accuracy and test_loss < best_loss):
            best_accuracy = test_accuracy
            best_loss = test_loss
            # Legacy (non-zip) serialization format — NOTE(review): private
            # keyword argument; confirm it is still needed for the consumers
            # of these files.
            torch.save(model.state_dict(), os.path.join(Gbase, 'best_model.pth') ,_use_new_zipfile_serialization=False )
            torch.save(optimizer.state_dict(), os.path.join(Gbase, 'optimizer.pth') ,_use_new_zipfile_serialization=False)
            print(f"New best model saved with test accuracy: {best_accuracy:.4f} and test loss: {best_loss:.4f}")
293
+
294
+
295
+ if __name__ == "__main__":
296
+ train1(NUM_SAMPLES=1000 ,maxNumShape=1, EPOCHS=50)
297
+ train1(NUM_SAMPLES=1000 ,maxNumShape=1, EPOCHS=50)
298
+ train1(NUM_SAMPLES=1000 ,maxNumShape=1, EPOCHS=50)
299
+ train1(NUM_SAMPLES=10000 ,maxNumShape=3, EPOCHS=50)
300
+ train1(NUM_SAMPLES=10000 ,maxNumShape=3, EPOCHS=50)
301
+ train1(NUM_SAMPLES=10000 ,maxNumShape=5, EPOCHS=50)
302
+ while True:
303
+ train1(NUM_SAMPLES=100000 ,maxNumShape=10, EPOCHS=5)
304
+
305
+
optimizer.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:820f2334048c67a59ebe18ff8b4b749a5a679d0e9facde598af5b697adee0f20
3
- size 34845043
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:415436b37bc12580dca1104c9e1c4696dd84dd3926ff4af982670d2b2f8a0594
3
+ size 144695175