Spaces:
Running
Running
# Detect Google Colab. There, mount Google Drive and keep checkpoints / caches on it;
# anywhere else, work relative to the current directory.
try:
    import google.colab  # noqa: F401  -- imported only to detect the Colab runtime
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    from google.colab import drive, files
    from google.colab import output
    # Mount errors now propagate. Previously a bare `except:` silently fell back to
    # "./" (presumably ephemeral storage on a Colab VM) and also swallowed Ctrl-C.
    drive.mount('/gdrive')
    Gbase = "/gdrive/MyDrive/generate/"
    cache_dir = "/gdrive/MyDrive/hf/"
    import sys
    sys.path.append(Gbase)
else:
    Gbase = "./"
    cache_dir = "./hf/"
import cv2 | |
import os | |
import numpy as np | |
import random | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
from torch.utils.data import Dataset, DataLoader | |
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Module-wide defaults; several functions below use these as default arguments.
IMAGE_SIZE = 64          # side length, in pixels, of the square synthetic images
NUM_SAMPLES = 1000       # default number of generated samples
BATCH_SIZE = 200         # DataLoader batch size
EPOCHS = 500             # default epoch count for train1
LEARNING_RATE = 0.001    # learning rate passed to Adam
def generate_sample(num_shapes=1):
    """Draw ``num_shapes`` random filled shapes onto a zeroed canvas.

    Each shape is picked from rectangle / circle / ellipse / polygon with a
    random colour. The equivalent cv2 call is recorded as a source string.

    Args:
        num_shapes: how many shapes to draw (one instruction string per shape).

    Returns:
        dict with ``'image'`` (uint8 array of shape (IMAGE_SIZE, IMAGE_SIZE, 3))
        and ``'instructions'`` (list of str, in drawing order).
    """
    canvas = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)
    recorded = []
    for _ in range(num_shapes):
        kind = random.choice(['rectangle', 'circle', 'ellipse', 'polygon'])
        colour = tuple(random.randint(0, 255) for _ in range(3))

        if kind == 'rectangle':
            x0 = random.randint(0, IMAGE_SIZE - 10)
            y0 = random.randint(0, IMAGE_SIZE - 10)
            # At least 10 px wide/tall; the far corner never exceeds IMAGE_SIZE.
            x1 = x0 + random.randint(10, IMAGE_SIZE - x0)
            y1 = y0 + random.randint(10, IMAGE_SIZE - y0)
            corner_a, corner_b = (x0, y0), (x1, y1)
            cv2.rectangle(canvas, corner_a, corner_b, colour, -1)
            recorded.append(f"cv2.rectangle(image, {corner_a}, {corner_b}, {colour}, -1)")
        elif kind == 'circle':
            cx = random.randint(10, IMAGE_SIZE - 10)
            cy = random.randint(10, IMAGE_SIZE - 10)
            # Radius capped by the distance to the nearest border, so the circle fits.
            r = random.randint(5, min(cx, cy, IMAGE_SIZE - cx, IMAGE_SIZE - cy))
            cv2.circle(canvas, (cx, cy), r, colour, -1)
            recorded.append(f"cv2.circle(image, {(cx, cy)}, {r}, {colour}, -1)")
        elif kind == 'ellipse':
            centre = (random.randint(10, IMAGE_SIZE - 10), random.randint(10, IMAGE_SIZE - 10))
            semi_axes = (random.randint(5, 30), random.randint(5, 30))
            rotation = random.randint(0, 360)
            cv2.ellipse(canvas, centre, semi_axes, rotation, 0, 360, colour, -1)
            recorded.append(f"cv2.ellipse(image, {centre}, {semi_axes}, {rotation}, 0, 360, {colour}, -1)")
        else:  # 'polygon'
            vertex_count = random.randint(3, 6)
            raw = [(random.randint(0, IMAGE_SIZE), random.randint(0, IMAGE_SIZE)) for _ in range(vertex_count)]
            vertices = np.array(raw, np.int32).reshape((-1, 1, 2))
            cv2.fillPoly(canvas, [vertices], colour)
            recorded.append(f"cv2.fillPoly(image, [{vertices.tolist()}], {colour})")

    return {'image': canvas, 'instructions': recorded}
def generate_dataset(NUM_SAMPLES=NUM_SAMPLES, maxNumShape=3):
    """Return a list of ``NUM_SAMPLES`` samples built by ``generate_sample``.

    Each sample's shape count is drawn with ``random.randint(1, maxNumShape)``,
    i.e. uniformly from 1 to ``maxNumShape`` inclusive.
    """
    return [
        generate_sample(num_shapes=random.randint(1, maxNumShape))
        for _ in range(NUM_SAMPLES)
    ]
class ImageDataset(Dataset):
    """Map-style dataset wrapping a list of sample dicts.

    Each item is ``(image_tensor, label)``:
      * ``image_tensor``: float32, channels-first, scaled by 1/255.
      * ``label``: ``len(sample['instructions'])``, the number of shapes drawn.

    The dicts are expected to look like ``generate_sample``'s output, with an
    H x W x C ``'image'`` array and an ``'instructions'`` list.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        pixels = torch.as_tensor(record['image'], dtype=torch.float32)
        channels_first = pixels.permute(2, 0, 1)  # H x W x C -> C x H x W
        shape_count = len(record['instructions'])
        return channels_first / 255.0, shape_count
class SelfAttention(nn.Module):
    """Self-attention across the spatial positions of a feature map.

    Queries and keys are 1x1-conv projections to ``in_channels // 8`` channels.
    Values keep ``in_channels``. The attended result is scaled by the learnable
    ``gamma`` (initialised to zero) and added back to the input. Right after
    construction the layer is therefore the identity.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        self.query = nn.Conv2d(in_channels, reduced, 1)
        self.key = nn.Conv2d(in_channels, reduced, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        n, c, h, w = x.size()
        positions = h * w
        q = self.query(x).reshape(n, -1, positions).transpose(1, 2)  # (n, positions, c // 8)
        k = self.key(x).reshape(n, -1, positions)                     # (n, c // 8, positions)
        weights = torch.softmax(torch.bmm(q, k), dim=-1)             # (n, positions, positions)
        v = self.value(x).reshape(n, -1, positions)                   # (n, c, positions)
        attended = torch.bmm(v, weights.transpose(1, 2)).reshape(n, c, h, w)
        return self.gamma * attended + x
class SimpleModel(nn.Module):
    """CNN + spatial self-attention regressor with one scalar output per image.

    ``predict`` rounds that scalar to a shape count.

    Args:
        path: optional path to a saved ``state_dict``. It is loaded only if the
            file exists; otherwise the freshly initialised weights are kept.
    """

    def __init__(self, path=None):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        # Adaptive pooling always yields 8x8 maps, which fixes fc1's input size.
        self.pool = nn.AdaptiveAvgPool2d((8, 8))
        self.attention = SelfAttention(256)
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 1)
        self.dropout = nn.Dropout(0.5)
        if path and os.path.exists(path):
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            self.load_state_dict(torch.load(path, map_location=device))

    def forward(self, x):
        """Map a (batch, 3, H, W) tensor to a (batch, 1) tensor of raw predictions."""
        # The first pool already reduces the map to 8x8. The later pools are then
        # no-ops, so conv2..conv4 and the attention layer all run at 8x8.
        x = self.pool(F.mish(self.bn1(self.conv1(x))))
        x = self.pool(F.mish(self.bn2(self.conv2(x))))
        x = F.mish(self.bn3(self.conv3(x)))
        x = F.mish(self.bn4(self.conv4(x)))
        x = self.attention(x)
        x = self.pool(x)
        x = x.view(-1, 256 * 8 * 8)
        x = F.mish(self.fc1(x))
        x = self.dropout(x)
        x = F.mish(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

    @staticmethod
    def _prepare_image(image):
        """Load/normalise ``image`` into an H x W x 3 array resized to IMAGE_SIZE x IMAGE_SIZE.

        Accepts an existing file path, or a numpy array that is 2-D (grayscale)
        or 3-D with 1, 3 or 4 channels.

        Raises:
            ValueError: for any other input, an undecodable file, or an
                unsupported array shape.
        """
        if isinstance(image, str) and os.path.isfile(image):
            img = cv2.imread(image)
            if img is None:  # cv2.imread returns None when it cannot decode the file
                raise ValueError(f"Could not read image file: {image}")
        elif isinstance(image, np.ndarray):
            img = image
            if img.ndim == 3 and img.shape[2] == 1:
                # Treat H x W x 1 as plain grayscale.
                img = np.ascontiguousarray(img[:, :, 0])
            if img.ndim == 2:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            elif img.ndim == 3 and img.shape[2] == 4:
                # NOTE(review): assumes OpenCV's BGRA channel order; alpha is dropped.
                img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
            elif not (img.ndim == 3 and img.shape[2] == 3):
                raise ValueError(f"Unsupported image array shape: {image.shape}")
        else:
            raise ValueError("Input should be an image file path or a numpy array")
        return cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))

    def predict(self, image):
        """Predict how many shapes ``image`` contains and return that many instructions.

        Only the *number* of instructions comes from the network. Each
        instruction's shape, position, size and colour are sampled with
        ``random`` and are not derived from the image. Leaves the module in
        eval mode.

        Args:
            image: path to an image file, or a numpy array (2-D grayscale, or
                H x W x 1 / 3 / 4).

        Returns:
            list[str]: cv2 drawing calls as source strings, one per predicted shape.

        Raises:
            ValueError: if ``image`` cannot be turned into a 3-channel image.
        """
        self.eval()
        with torch.no_grad():
            img = self._prepare_image(image)
            # Use the device the weights actually live on, not a global guess.
            param_device = next(self.parameters()).device
            img_tensor = torch.as_tensor(img, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
            output = self(img_tensor.to(param_device)).item()
        # A negative regression output simply means "no shapes".
        num_instructions = max(0, round(output))
        instructions = []
        for _ in range(num_instructions):
            shape = random.choice(['rectangle', 'circle', 'ellipse', 'polygon'])
            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            if shape == 'rectangle':
                instructions.append(f"cv2.rectangle(image, {(random.randint(0, IMAGE_SIZE-10), random.randint(0, IMAGE_SIZE-10))}, {(random.randint(10, IMAGE_SIZE), random.randint(10, IMAGE_SIZE))}, {color}, -1)")
            elif shape == 'circle':
                instructions.append(f"cv2.circle(image, {(random.randint(10, IMAGE_SIZE-10), random.randint(10, IMAGE_SIZE-10))}, {random.randint(5, 30)}, {color}, -1)")
            elif shape == 'ellipse':
                instructions.append(f"cv2.ellipse(image, {(random.randint(10, IMAGE_SIZE-10), random.randint(10, IMAGE_SIZE-10))}, {(random.randint(5, 30), random.randint(5, 30))}, {random.randint(0, 360)}, 0, 360, {color}, -1)")
            elif shape == 'polygon':
                num_points = random.randint(3, 6)
                points = [(random.randint(0, IMAGE_SIZE), random.randint(0, IMAGE_SIZE)) for _ in range(num_points)]
                # int32 so the emitted call is one cv2.fillPoly accepts (as in generate_sample).
                instructions.append(f"cv2.fillPoly(image, [np.array({points}, np.int32)], {color})")
        return instructions
def train(model, train_loader, optimizer, criterion):
    """Run one training epoch over ``train_loader``.

    ``model`` is expected to return one value per sample (shape (B, 1) or (B,)).
    Targets are integer shape counts; a prediction counts as correct when its
    rounded value equals the target.

    Returns:
        (mean_loss, accuracy):
          * mean_loss: per-sample average of ``criterion``. Each batch's loss is
            weighted by its size, so a short final batch is not over-counted.
            This assumes ``criterion`` uses mean reduction, as ``nn.MSELoss()``
            does by default.
          * accuracy: fraction of correctly rounded predictions.
        Both are 0.0 for an empty loader.
    """
    model.train()
    # Move batches to wherever the model's weights live; fall back to the
    # module-level ``device`` only for a parameter-less model.
    first_param = next(model.parameters(), None)
    batch_device = first_param.device if first_param is not None else device
    total_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(batch_device), target.float().to(batch_device)
        optimizer.zero_grad()
        # reshape(-1), not squeeze(): squeeze() turns a batch of one into a 0-d
        # tensor whose shape no longer matches target's (1,).
        output = model(data).reshape(-1)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        batch_size = target.size(0)
        total_loss += loss.item() * batch_size
        # Count correct predictions
        predicted = torch.round(output)
        correct_predictions += (predicted == target).sum().item()
        total_predictions += batch_size
        if batch_idx % 3000 == 0:
            print(f'Train Batch {batch_idx}/{len(train_loader)} Loss: {loss.item():.6f}')
    if total_predictions == 0:
        return 0.0, 0.0
    return total_loss / total_predictions, correct_predictions / total_predictions
def test(model, test_loader, criterion, print_predictions=False):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Prints the average loss and accuracy. With ``print_predictions`` it also
    prints the first ten (prediction, target) pairs.

    Returns:
        (test_loss, accuracy, all_predictions, all_targets):
          * test_loss: per-sample average of ``criterion``, with each batch
            weighted by its size. This assumes a mean-reduced criterion.
          * accuracy: fraction of samples whose rounded prediction equals the target.
          * all_predictions / all_targets: flat lists of per-sample numpy scalars.
        Loss and accuracy are 0.0 for an empty loader.
    """
    model.eval()
    # Move batches to the model's own device; the module-level ``device`` is
    # only a fallback for parameter-less models.
    first_param = next(model.parameters(), None)
    batch_device = first_param.device if first_param is not None else device
    test_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(batch_device), target.float().to(batch_device)
            # reshape(-1), not squeeze(): a batch of one must stay 1-D. Otherwise
            # the loss shapes mismatch and extend() below fails on a 0-d array.
            output = model(data).reshape(-1)
            batch_size = target.size(0)
            test_loss += criterion(output, target).item() * batch_size
            predicted = torch.round(output)
            correct_predictions += (predicted == target).sum().item()
            total_predictions += batch_size
            all_predictions.extend(output.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    if total_predictions:
        test_loss /= total_predictions
        accuracy = correct_predictions / total_predictions
    else:
        accuracy = 0.0
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}')
    if print_predictions:
        print("Sample predictions:")
        for pred, targ in zip(all_predictions[:10], all_targets[:10]):
            print(f"Prediction: {pred:.2f}, Target: {targ:.2f}")
    return test_loss, accuracy, all_predictions, all_targets
def train1(NUM_SAMPLES=NUM_SAMPLES, maxNumShape=1, EPOCHS=EPOCHS):
    """Generate a synthetic dataset and train (or resume) ``SimpleModel`` on it.

    Resumes from ``best_model.pth`` / ``optimizer.pth`` under ``Gbase`` when those
    files exist. Overwrites both files whenever an epoch's test accuracy beats the
    best seen *in this call*, with ties broken by lower test loss.

    Args:
        NUM_SAMPLES: number of samples to generate (first 80% train, rest test).
        maxNumShape: each sample contains between 1 and this many shapes.
        EPOCHS: number of training epochs.
    """
    # Seed every RNG *before* the model is built, so a fresh initialisation
    # (no checkpoint on disk) is reproducible along with the dataset.
    seed = 618 * 382 * 33
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # torch.save fails if the checkpoint directory is missing (e.g. a fresh
    # Drive folder), so create it up front.
    os.makedirs(Gbase, exist_ok=True)
    model_path = os.path.join(Gbase, 'best_model.pth')
    optimizer_path = os.path.join(Gbase, 'optimizer.pth')

    model = SimpleModel(path=model_path).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    if os.path.exists(optimizer_path):
        print("Loading optimizer state...")
        optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))
    criterion = nn.MSELoss()

    dataset = generate_dataset(NUM_SAMPLES=NUM_SAMPLES, maxNumShape=maxNumShape)
    train_size = int(0.8 * len(dataset))
    train_dataset = ImageDataset(dataset[:train_size])
    test_dataset = ImageDataset(dataset[train_size:])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # Evaluation needs no shuffling. A fixed order also keeps the "Sample
    # predictions" printout on the same images from epoch to epoch.
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    best_loss = float('inf')
    best_accuracy = 0
    for epoch in range(EPOCHS):
        print(f'Epoch {epoch+1}/{EPOCHS}')
        train_loss, train_accuracy = train(model, train_loader, optimizer, criterion)
        test_loss, test_accuracy, _, _ = test(model, test_loader, criterion, print_predictions=True)
        print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}')
        print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')
        if test_accuracy > best_accuracy or (test_accuracy == best_accuracy and test_loss < best_loss):
            best_accuracy = test_accuracy
            best_loss = test_loss
            torch.save(model.state_dict(), model_path, _use_new_zipfile_serialization=False)
            torch.save(optimizer.state_dict(), optimizer_path, _use_new_zipfile_serialization=False)
            print(f"New best model saved with test accuracy: {best_accuracy:.4f} and test loss: {best_loss:.4f}")
if __name__ == "__main__":
    # Curriculum of (NUM_SAMPLES, maxNumShape, EPOCHS) runs, easiest first.
    # Each train1 call reloads the checkpoint files the previous call wrote.
    _curriculum = (
        (1000, 1, 50),
        (1000, 1, 50),
        (1000, 1, 50),
        (10000, 3, 50),
        (10000, 3, 50),
        (10000, 5, 50),
    )
    for _samples, _max_shapes, _epochs in _curriculum:
        train1(NUM_SAMPLES=_samples, maxNumShape=_max_shapes, EPOCHS=_epochs)
    # Then keep training indefinitely on the largest, hardest setting.
    while True:
        train1(NUM_SAMPLES=100000, maxNumShape=10, EPOCHS=5)