Commit d9041f1
Parent(s): bd0c530
output fix

Files changed:
- __pycache__/model.cpython-310.pyc +0 -0
- __pycache__/solver.cpython-310.pyc +0 -0
- app.py +64 -60
- examples/Captura de pantalla de 2022-10-09 15-30-40.png +0 -0
- examples/Captura de pantalla de 2022-10-15 17-57-27.png +0 -0
- examples/Captura de pantalla de 2022-10-15 17-57-38.png +0 -0
- helper/__pycache__/data_setup.cpython-310.pyc +0 -0
- helper/__pycache__/engine.cpython-310.pyc +0 -0
- helper/__pycache__/helper_functions.cpython-310.pyc +0 -0
- helper/__pycache__/predictions.cpython-310.pyc +0 -0
- helper/data_setup.py +0 -66
- helper/engine.py +0 -195
- helper/helper_functions.py +0 -294
- helper/model_builder.py +0 -56
- helper/predictions.py +0 -83
- helper/train.py +0 -62
- helper/utils.py +0 -35
- model.py +0 -68
- model/2model28.h5 +0 -3
- solver.py +71 -47
- src/__pycache__/solve.cpython-310.pyc +0 -0
- src/__pycache__/solve.cpython-39.pyc +0 -0
- src/__pycache__/tesseract.cpython-310.pyc +0 -0
- src/__pycache__/tesseract.cpython-39.pyc +0 -0
- src/solve.py +6 -6
- words/descarga.png +0 -0
- words/words1.png +0 -0
__pycache__/model.cpython-310.pyc
DELETED
Binary file (2.1 kB)
__pycache__/solver.cpython-310.pyc
DELETED
Binary file (5.95 kB)
app.py
CHANGED
@@ -1,4 +1,4 @@
-### 1. Imports and class names setup ###
+### 1. Imports and class names setup ###
 from ast import Interactive
 import gradio as gr
 import os
@@ -13,6 +13,8 @@ import pytesseract
 import re
 import shutil
 import solver
+import glob
+from PIL import Image
 
 from model import create_model
 from timeit import default_timer as timer
@@ -30,81 +32,83 @@ def parse_args() -> argparse.Namespace:
                         action='store_false')
     return parser.parse_args()
 
+
 def set_example_image(example: list) -> dict:
     return gr.Image.update(value=example[0])
 
 
+def update_dataset(example: list) -> dict:
+    return gr.Gallery.update(value=example[0])
+
+
 # Setup class names
-with open("class_names.txt", "r") as f:
-    class_names = [names.strip() for names in
-    )
-
-model_created = model_base(input_shape=1, hidden_units=10, output_shape=len(class_names))
-
-model_1 = model_created.load_state_dict(
-    torch.load(
-        f="model_1.pth",
-        map_location=torch.device("cpu"),  # load to CPU
-    )
-)
+with open("class_names.txt", "r") as f:  # reading them in from class_names.txt
+    class_names = [names.strip() for names in f.readlines()]
+
+
+def get_images():
+    images_list = []
+    for filename in glob.glob('wordsPuzzle/*.jpg'):  # assuming png
+        im = Image.open(filename)
+        images_list.append(im)
+    return images_list
 
 
 def main():
     args = parse_args()
     args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    print('*** Now using %s.'%(args.device))
+    print('*** Now using %s.' % (args.device))
 
     with gr.Blocks(theme=args.theme, css='style.css') as demo:
         gr.Markdown('''# World Puzzle Solver 🧩''')
+        gr.Markdown('''## (Works in Spanish too!) 🇪🇸''')
 
         with gr.Box():
-            gr.Markdown(
+            gr.Markdown(
+                '''### Insert a Word Puzzle Image in both boxes and crop the board and words''')
+        with gr.Row():
            with gr.Box():
                with gr.Column():
-                    gr.Markdown('''Images 🖼️''')
+                    gr.Markdown('''Images 🖼️''')
                    with gr.Row():
-                        input_board= gr.Image(label='Board',
+                        input_board = gr.Image(label='Board',
+                                               type='filepath',
+                                               interactive=True,)
                    with gr.Row():
                        crop_board_button = gr.Button('Crop Board ✂️')
                    with gr.Row():
                        input_words = gr.Image(label='Words',
+                                              type='filepath',
+                                              interactive=True, height="auto")
                    with gr.Row():
                        crop_words_button = gr.Button('Crop Words ✂️')
-                    with gr.
-                    paths = [["examples/" + example]
+                    with gr.Row():
+                        # Create examples list from "examples/" directory
+                        paths = [["examples/" + example]
+                                 for example in os.listdir("examples")]
                        example_images = gr.Dataset(components=([input_board]),
+                                                   samples=[[path]
+                                                            for path in paths],
+                                                   label='Image Examples (Drag and drop into both boxes) then crop using the tool button')
 
            with gr.Box():
                with gr.Column():
                    gr.Markdown('''Cropped Images ✂️''')
                    with gr.Row():
                        cropped_board = gr.Image(label='Board Cropped',
+                                                type='filepath',
+                                                interactive=False, height="auto")
                    instyle = gr.Variable()
                    with gr.Row():
                        cropped_words = gr.Image(label='Words Cropped',
+                                                type='filepath',
+                                                interactive=False)
                    instyle = gr.Variable()
                    with gr.Row():
                        find_words_button = gr.Button('Find Words 🔍')
                    with gr.Row():
-                        words_found = gr.Textbox(
+                        words_found = gr.Textbox(
+                            label='Words detected (edit if wrong)', interactive=True, value='')
                    with gr.Row():
                        solve_button = gr.Button('Solve! 📝')
 
@@ -112,36 +116,36 @@ def main():
            with gr.Column():
                gr.Markdown('''Solution ✅''')
                with gr.Row():
-                    board_solved = gr.Image(
+                    board_solved = gr.Image(
+                        type='filepath',
+                        interactive=False)
-                with gr.
+                with gr.Row():
+                    show_words_board = gr.Button(
+                        'Show words seperately 📝')
+                with gr.Row():
+                    gallery = gr.Gallery(
+                        label=None, show_label=True, elem_id="gallery"
+                    ).style(grid=[4], height="auto")
+
 
        crop_board_button.click(fn=None,
+                               inputs=[input_board],
+                               outputs=[cropped_board])
        crop_words_button.click(fn=None,
+                               inputs=[input_words],
+                               outputs=[cropped_words])
        find_words_button.click(solver.get_words,
+                               inputs=cropped_words,
+                               outputs=words_found)
        solve_button.click(solver.solve_puzzle,
+                          inputs=[cropped_board, words_found],
+                          outputs=board_solved)
 
        example_images.click(fn=set_example_image,
+                            inputs=example_images,
+                            outputs=example_images.components)
+        show_words_board.click(get_images, None, gallery)
 
        demo.launch(
            enable_queue=args.enable_queue,
            server_port=args.port,
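Note on the new gallery wiring: show_words_board.click(get_images, None, gallery) takes no inputs and feeds whatever solver.solve_puzzle() has written into wordsPuzzle/ straight into the gr.Gallery. Below is a minimal standalone sketch of the same pattern (the wordsPuzzle/ directory and button label come from this diff; the function name is illustrative). One detail worth flagging: the glob pattern matches .jpg files even though the inline comment says "assuming png".

    import glob

    import gradio as gr
    from PIL import Image


    def load_word_images() -> list:
        # gr.Gallery accepts a list of PIL images (or file paths);
        # solve_puzzle() writes one .jpg per found word into wordsPuzzle/.
        return [Image.open(p) for p in sorted(glob.glob("wordsPuzzle/*.jpg"))]


    with gr.Blocks() as demo:
        show = gr.Button('Show words seperately 📝')
        gallery = gr.Gallery(elem_id="gallery")
        show.click(load_word_images, None, gallery)  # fn, inputs, outputs
    demo.launch()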
examples/Captura de pantalla de 2022-10-09 15-30-40.png
DELETED
Binary file (178 kB)
examples/Captura de pantalla de 2022-10-15 17-57-27.png
ADDED
examples/Captura de pantalla de 2022-10-15 17-57-38.png
ADDED
helper/__pycache__/data_setup.cpython-310.pyc
DELETED
Binary file (1.97 kB)
helper/__pycache__/engine.cpython-310.pyc
DELETED
Binary file (4.95 kB)
helper/__pycache__/helper_functions.cpython-310.pyc
DELETED
Binary file (8.32 kB)
helper/__pycache__/predictions.cpython-310.pyc
DELETED
Binary file (2.3 kB)
helper/data_setup.py
DELETED
@@ -1,66 +0,0 @@
"""
Contains functionality for creating PyTorch DataLoaders for
image classification data.
"""
import os

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

NUM_WORKERS = os.cpu_count()

def create_dataloaders(
    train_dir: str,
    test_dir: str,
    train_transform: transforms.Compose,
    test_transform: transforms.Compose,
    batch_size: int,
    num_workers: int=NUM_WORKERS
):
    """Creates training and testing DataLoaders.

    Takes in a training directory and testing directory path and turns
    them into PyTorch Datasets and then into PyTorch DataLoaders.

    Args:
        train_dir: Path to training directory.
        test_dir: Path to testing directory.
        transform: torchvision transforms to perform on training and testing data.
        batch_size: Number of samples per batch in each of the DataLoaders.
        num_workers: An integer for number of workers per DataLoader.

    Returns:
        A tuple of (train_dataloader, test_dataloader, class_names).
        Where class_names is a list of the target classes.
        Example usage:
            train_dataloader, test_dataloader, class_names = \
                = create_dataloaders(train_dir=path/to/train_dir,
                                     test_dir=path/to/test_dir,
                                     transform=some_transform,
                                     batch_size=32,
                                     num_workers=4)
    """
    # Use ImageFolder to create dataset(s)
    train_data = datasets.ImageFolder(train_dir, transform=train_transform)
    test_data = datasets.ImageFolder(test_dir, transform=test_transform)

    # Get class names
    class_names = train_data.classes

    # Turn images into data loaders
    train_dataloader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    test_dataloader = DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_dataloader, test_dataloader, class_names
helper/engine.py
DELETED
@@ -1,195 +0,0 @@
"""
Contains functions for training and testing a PyTorch model.
"""
import torch

from tqdm.auto import tqdm
from typing import Dict, List, Tuple

def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:
    """Trains a PyTorch model for a single epoch.

    Turns a target PyTorch model to training mode and then
    runs through all of the required training steps (forward
    pass, loss calculation, optimizer step).

    Args:
        model: A PyTorch model to be trained.
        dataloader: A DataLoader instance for the model to be trained on.
        loss_fn: A PyTorch loss function to minimize.
        optimizer: A PyTorch optimizer to help minimize the loss function.
        device: A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
        A tuple of training loss and training accuracy metrics.
        In the form (train_loss, train_accuracy). For example:

        (0.1112, 0.8743)
    """
    # Put model in train mode
    model.train()

    # Setup train loss and train accuracy values
    train_loss, train_acc = 0, 0

    # Loop through data loader data batches
    for batch, (X, y) in enumerate(dataloader):
        # Send data to target device
        X, y = X.to(device), y.to(device)

        # 1. Forward pass
        y_pred = model(X)

        # 2. Calculate and accumulate loss
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()

        # 3. Optimizer zero grad
        optimizer.zero_grad()

        # 4. Loss backward
        loss.backward()

        # 5. Optimizer step
        optimizer.step()

        # Calculate and accumulate accuracy metric across all batches
        y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)
        train_acc += (y_pred_class == y).sum().item()/len(y_pred)

    # Adjust metrics to get average loss and accuracy per batch
    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    return train_loss, train_acc

def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              device: torch.device) -> Tuple[float, float]:
    """Tests a PyTorch model for a single epoch.

    Turns a target PyTorch model to "eval" mode and then performs
    a forward pass on a testing dataset.

    Args:
        model: A PyTorch model to be tested.
        dataloader: A DataLoader instance for the model to be tested on.
        loss_fn: A PyTorch loss function to calculate loss on the test data.
        device: A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
        A tuple of testing loss and testing accuracy metrics.
        In the form (test_loss, test_accuracy). For example:

        (0.0223, 0.8985)
    """
    # Put model in eval mode
    model.eval()

    # Setup test loss and test accuracy values
    test_loss, test_acc = 0, 0

    # Turn on inference context manager
    with torch.inference_mode():
        # Loop through DataLoader batches
        for batch, (X, y) in enumerate(dataloader):
            # Send data to target device
            X, y = X.to(device), y.to(device)

            # 1. Forward pass
            test_pred_logits = model(X)

            # 2. Calculate and accumulate loss
            loss = loss_fn(test_pred_logits, y)
            test_loss += loss.item()

            # Calculate and accumulate accuracy
            test_pred_labels = test_pred_logits.argmax(dim=1)
            test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))

    # Adjust metrics to get average loss and accuracy per batch
    test_loss = test_loss / len(dataloader)
    test_acc = test_acc / len(dataloader)
    return test_loss, test_acc

def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device) -> Dict[str, List]:
    """Trains and tests a PyTorch model.

    Passes a target PyTorch models through train_step() and test_step()
    functions for a number of epochs, training and testing the model
    in the same epoch loop.

    Calculates, prints and stores evaluation metrics throughout.

    Args:
        model: A PyTorch model to be trained and tested.
        train_dataloader: A DataLoader instance for the model to be trained on.
        test_dataloader: A DataLoader instance for the model to be tested on.
        optimizer: A PyTorch optimizer to help minimize the loss function.
        loss_fn: A PyTorch loss function to calculate loss on both datasets.
        epochs: An integer indicating how many epochs to train for.
        device: A target device to compute on (e.g. "cuda" or "cpu").

    Returns:
        A dictionary of training and testing loss as well as training and
        testing accuracy metrics. Each metric has a value in a list for
        each epoch.
        In the form: {train_loss: [...],
                      train_acc: [...],
                      test_loss: [...],
                      test_acc: [...]}
        For example if training for epochs=2:
            {train_loss: [2.0616, 1.0537],
             train_acc: [0.3945, 0.3945],
             test_loss: [1.2641, 1.5706],
             test_acc: [0.3400, 0.2973]}
    """
    # Create empty results dictionary
    results = {"train_loss": [],
               "train_acc": [],
               "test_loss": [],
               "test_acc": []
    }

    # Make sure model on target device
    model.to(device)

    # Loop through training and testing steps for a number of epochs
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer,
                                           device=device)
        test_loss, test_acc = test_step(model=model,
                                        dataloader=test_dataloader,
                                        loss_fn=loss_fn,
                                        device=device)

        # Print out what's happening
        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_acc:.4f}"
        )

        # Update results dictionary
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)

    # Return the filled results at the end of the epochs
    return results
helper/helper_functions.py
DELETED
@@ -1,294 +0,0 @@
"""
A series of helper functions used throughout the course.

If a function gets defined once and could be used over and over, it'll go in here.
"""
import torch
import matplotlib.pyplot as plt
import numpy as np

from torch import nn

import os
import zipfile

from pathlib import Path

import requests

# Walk through an image classification directory and find out how many files (images)
# are in each subdirectory.
import os

def walk_through_dir(dir_path):
    """
    Walks through dir_path returning its contents.
    Args:
        dir_path (str): target directory

    Returns:
        A print out of:
            number of subdiretories in dir_path
            number of images (files) in each subdirectory
            name of each subdirectory
    """
    for dirpath, dirnames, filenames in os.walk(dir_path):
        print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")

def plot_decision_boundary(model: torch.nn.Module, X: torch.Tensor, y: torch.Tensor):
    """Plots decision boundaries of model predicting on X in comparison to y.

    Source - https://madewithml.com/courses/foundations/neural-networks/ (with modifications)
    """
    # Put everything to CPU (works better with NumPy + Matplotlib)
    model.to("cpu")
    X, y = X.to("cpu"), y.to("cpu")

    # Setup prediction boundaries and grid
    x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1
    y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 101), np.linspace(y_min, y_max, 101))

    # Make features
    X_to_pred_on = torch.from_numpy(np.column_stack((xx.ravel(), yy.ravel()))).float()

    # Make predictions
    model.eval()
    with torch.inference_mode():
        y_logits = model(X_to_pred_on)

    # Test for multi-class or binary and adjust logits to prediction labels
    if len(torch.unique(y)) > 2:
        y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)  # mutli-class
    else:
        y_pred = torch.round(torch.sigmoid(y_logits))  # binary

    # Reshape preds and plot
    y_pred = y_pred.reshape(xx.shape).detach().numpy()
    plt.contourf(xx, yy, y_pred, cmap=plt.cm.RdYlBu, alpha=0.7)
    plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.RdYlBu)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())


# Plot linear data or training and test and predictions (optional)
def plot_predictions(
    train_data, train_labels, test_data, test_labels, predictions=None
):
    """
    Plots linear training data and test data and compares predictions.
    """
    plt.figure(figsize=(10, 7))

    # Plot training data in blue
    plt.scatter(train_data, train_labels, c="b", s=4, label="Training data")

    # Plot test data in green
    plt.scatter(test_data, test_labels, c="g", s=4, label="Testing data")

    if predictions is not None:
        # Plot the predictions in red (predictions were made on the test data)
        plt.scatter(test_data, predictions, c="r", s=4, label="Predictions")

    # Show the legend
    plt.legend(prop={"size": 14})


# Calculate accuracy (a classification metric)
def accuracy_fn(y_true, y_pred):
    """Calculates accuracy between truth labels and predictions.

    Args:
        y_true (torch.Tensor): Truth labels for predictions.
        y_pred (torch.Tensor): Predictions to be compared to predictions.

    Returns:
        [torch.float]: Accuracy value between y_true and y_pred, e.g. 78.45
    """
    correct = torch.eq(y_true, y_pred).sum().item()
    acc = (correct / len(y_pred)) * 100
    return acc


def print_train_time(start, end, device=None):
    """Prints difference between start and end time.

    Args:
        start (float): Start time of computation (preferred in timeit format).
        end (float): End time of computation.
        device ([type], optional): Device that compute is running on. Defaults to None.

    Returns:
        float: time between start and end in seconds (higher is longer).
    """
    total_time = end - start
    print(f"\nTrain time on {device}: {total_time:.3f} seconds")
    return total_time


# Plot loss curves of a model
def plot_loss_curves(results):
    """Plots training curves of a results dictionary.

    Args:
        results (dict): dictionary containing list of values, e.g.
            {"train_loss": [...],
             "train_acc": [...],
             "test_loss": [...],
             "test_acc": [...]}
    """
    loss = results["train_loss"]
    test_loss = results["test_loss"]

    accuracy = results["train_acc"]
    test_accuracy = results["test_acc"]

    epochs = range(len(results["train_loss"]))

    plt.figure(figsize=(15, 7))

    # Plot loss
    plt.subplot(1, 2, 1)
    plt.plot(epochs, loss, label="train_loss")
    plt.plot(epochs, test_loss, label="test_loss")
    plt.title("Loss")
    plt.xlabel("Epochs")
    plt.legend()

    # Plot accuracy
    plt.subplot(1, 2, 2)
    plt.plot(epochs, accuracy, label="train_accuracy")
    plt.plot(epochs, test_accuracy, label="test_accuracy")
    plt.title("Accuracy")
    plt.xlabel("Epochs")
    plt.legend()


# Pred and plot image function from notebook 04
# See creation: https://www.learnpytorch.io/04_pytorch_custom_datasets/#113-putting-custom-image-prediction-together-building-a-function
from typing import List
import torchvision


def pred_and_plot_image(
    model: torch.nn.Module,
    image_path: str,
    class_names: List[str] = None,
    transform=None,
    device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Makes a prediction on a target image with a trained model and plots the image.

    Args:
        model (torch.nn.Module): trained PyTorch image classification model.
        image_path (str): filepath to target image.
        class_names (List[str], optional): different class names for target image. Defaults to None.
        transform (_type_, optional): transform of target image. Defaults to None.
        device (torch.device, optional): target device to compute on. Defaults to "cuda" if torch.cuda.is_available() else "cpu".

    Returns:
        Matplotlib plot of target image and model prediction as title.

    Example usage:
        pred_and_plot_image(model=model,
                            image="some_image.jpeg",
                            class_names=["class_1", "class_2", "class_3"],
                            transform=torchvision.transforms.ToTensor(),
                            device=device)
    """

    # 1. Load in image and convert the tensor values to float32
    target_image = torchvision.io.read_image(str(image_path)).type(torch.float32)

    # 2. Divide the image pixel values by 255 to get them between [0, 1]
    target_image = target_image / 255.0

    # 3. Transform if necessary
    if transform:
        target_image = transform(target_image)

    # 4. Make sure the model is on the target device
    model.to(device)

    # 5. Turn on model evaluation mode and inference mode
    model.eval()
    with torch.inference_mode():
        # Add an extra dimension to the image
        target_image = target_image.unsqueeze(dim=0)

        # Make a prediction on image with an extra dimension and send it to the target device
        target_image_pred = model(target_image.to(device))

    # 6. Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
    target_image_pred_probs = torch.softmax(target_image_pred, dim=1)

    # 7. Convert prediction probabilities -> prediction labels
    target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)

    # 8. Plot the image alongside the prediction and prediction probability
    plt.imshow(
        target_image.squeeze().permute(1, 2, 0)
    )  # make sure it's the right size for matplotlib
    if class_names:
        title = f"Pred: {class_names[target_image_pred_label.cpu()]} | Prob: {target_image_pred_probs.max().cpu():.3f}"
    else:
        title = f"Pred: {target_image_pred_label} | Prob: {target_image_pred_probs.max().cpu():.3f}"
    plt.title(title)
    plt.axis(False)

def set_seeds(seed: int=42):
    """Sets random sets for torch operations.

    Args:
        seed (int, optional): Random seed to set. Defaults to 42.
    """
    # Set the seed for general torch operations
    torch.manual_seed(seed)
    # Set the seed for CUDA torch operations (ones that happen on the GPU)
    torch.cuda.manual_seed(seed)

def download_data(source: str,
                  destination: str,
                  remove_source: bool = True) -> Path:
    """Downloads a zipped dataset from source and unzips to destination.

    Args:
        source (str): A link to a zipped file containing data.
        destination (str): A target directory to unzip data to.
        remove_source (bool): Whether to remove the source after downloading and extracting.

    Returns:
        pathlib.Path to downloaded data.

    Example usage:
        download_data(source="https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
                      destination="pizza_steak_sushi")
    """
    # Setup path to data folder
    data_path = Path("data/")
    image_path = data_path / destination

    # If the image folder doesn't exist, download it and prepare it...
    if image_path.is_dir():
        print(f"[INFO] {image_path} directory exists, skipping download.")
    else:
        print(f"[INFO] Did not find {image_path} directory, creating one...")
        image_path.mkdir(parents=True, exist_ok=True)

        # Download pizza, steak, sushi data
        target_file = Path(source).name
        with open(data_path / target_file, "wb") as f:
            request = requests.get(source)
            print(f"[INFO] Downloading {target_file} from {source}...")
            f.write(request.content)

        # Unzip pizza, steak, sushi data
        with zipfile.ZipFile(data_path / target_file, "r") as zip_ref:
            print(f"[INFO] Unzipping {target_file} data...")
            zip_ref.extractall(image_path)

        # Remove .zip file
        if remove_source:
            os.remove(data_path / target_file)

    return image_path
helper/model_builder.py
DELETED
@@ -1,56 +0,0 @@
"""
Contains PyTorch model code to instantiate a TinyVGG model.
"""
import torch
from torch import nn

class TinyVGG(nn.Module):
    """Creates the TinyVGG architecture.

    Replicates the TinyVGG architecture from the CNN explainer website in PyTorch.
    See the original architecture here: https://poloclub.github.io/cnn-explainer/

    Args:
        input_shape: An integer indicating number of input channels.
        hidden_units: An integer indicating number of hidden units between layers.
        output_shape: An integer indicating number of output units.
    """
    def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:
        super().__init__()
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(in_channels=input_shape,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1,
                      padding=0),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1,
                      padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,
                         stride=2)
        )
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # Where did this in_features shape come from?
            # It's because each layer of our network compresses and changes the shape of our inputs data.
            nn.Linear(in_features=hidden_units*13*13,
                      out_features=output_shape)
        )

    def forward(self, x: torch.Tensor):
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.classifier(x)
        return x
        # return self.classifier(self.block_2(self.block_1(x)))  # <- leverage the benefits of operator fusion
helper/predictions.py
DELETED
@@ -1,83 +0,0 @@
"""
Utility functions to make predictions.

Main reference for code creation: https://www.learnpytorch.io/06_pytorch_transfer_learning/#6-make-predictions-on-images-from-the-test-set
"""
import torch
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt

from typing import List, Tuple

from PIL import Image

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Predict on a target image with a target model
# Function created in: https://www.learnpytorch.io/06_pytorch_transfer_learning/#6-make-predictions-on-images-from-the-test-set
def pred_and_plot_image(
    model: torch.nn.Module,
    class_names: List[str],
    image_path: str,
    image_size: Tuple[int, int] = (224, 224),
    transform: torchvision.transforms = None,
    device: torch.device = device,
):
    """Predicts on a target image with a target model.

    Args:
        model (torch.nn.Module): A trained (or untrained) PyTorch model to predict on an image.
        class_names (List[str]): A list of target classes to map predictions to.
        image_path (str): Filepath to target image to predict on.
        image_size (Tuple[int, int], optional): Size to transform target image to. Defaults to (224, 224).
        transform (torchvision.transforms, optional): Transform to perform on image. Defaults to None which uses ImageNet normalization.
        device (torch.device, optional): Target device to perform prediction on. Defaults to device.
    """

    # Open image
    img = Image.open(image_path)

    # Create transformation for image (if one doesn't exist)
    if transform is not None:
        image_transform = transform
    else:
        image_transform = transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    ### Predict on image ###

    # Make sure the model is on the target device
    model.to(device)

    # Turn on model evaluation mode and inference mode
    model.eval()
    with torch.inference_mode():
        # Transform and add an extra dimension to image (model requires samples in [batch_size, color_channels, height, width])
        transformed_image = image_transform(img).unsqueeze(dim=0)

        # Make a prediction on image with an extra dimension and send it to the target device
        target_image_pred = model(transformed_image.to(device))

    # Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
    target_image_pred_probs = torch.softmax(target_image_pred, dim=1)

    # Convert prediction probabilities -> prediction labels
    target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)

    # Plot image with predicted label and probability
    plt.figure()
    plt.imshow(img)
    plt.title(
        f"Pred: {class_names[target_image_pred_label]} | Prob: {target_image_pred_probs.max():.3f}"
    )
    plt.axis(False)
helper/train.py
DELETED
@@ -1,62 +0,0 @@
"""
Trains a PyTorch image classification model using device-agnostic code.
"""

import os
import torch
import data_setup, engine, model_builder, utils

from torchvision import transforms

# Setup hyperparameters
NUM_EPOCHS = 5
BATCH_SIZE = 32
HIDDEN_UNITS = 10
LEARNING_RATE = 0.001

# Setup directories
train_dir = "data/pizza_steak_sushi/train"
test_dir = "data/pizza_steak_sushi/test"

# Setup target device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Create transforms
data_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

# Create DataLoaders with help from data_setup.py
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=data_transform,
    batch_size=BATCH_SIZE
)

# Create model with help from model_builder.py
model = model_builder.TinyVGG(
    input_shape=3,
    hidden_units=HIDDEN_UNITS,
    output_shape=len(class_names)
).to(device)

# Set loss and optimizer
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),
                             lr=LEARNING_RATE)

# Start training with help from engine.py
engine.train(model=model,
             train_dataloader=train_dataloader,
             test_dataloader=test_dataloader,
             loss_fn=loss_fn,
             optimizer=optimizer,
             epochs=NUM_EPOCHS,
             device=device)

# Save the model with help from utils.py
utils.save_model(model=model,
                 target_dir="models",
                 model_name="05_going_modular_script_mode_tinyvgg_model.pth")
helper/utils.py
DELETED
@@ -1,35 +0,0 @@
"""
Contains various utility functions for PyTorch model training and saving.
"""
import torch
from pathlib import Path

def save_model(model: torch.nn.Module,
               target_dir: str,
               model_name: str):
    """Saves a PyTorch model to a target directory.

    Args:
        model: A target PyTorch model to save.
        target_dir: A directory for saving the model to.
        model_name: A filename for the saved model. Should include
            either ".pth" or ".pt" as the file extension.

    Example usage:
        save_model(model=model_0,
                   target_dir="models",
                   model_name="05_going_modular_tingvgg_model.pth")
    """
    # Create target directory
    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True,
                          exist_ok=True)

    # Create model save path
    assert model_name.endswith(".pth") or model_name.endswith(".pt"), "model_name should end with '.pt' or '.pth'"
    model_save_path = target_dir_path / model_name

    # Save the model state_dict()
    print(f"[INFO] Saving model to: {model_save_path}")
    torch.save(obj=model.state_dict(),
               f=model_save_path)
model.py
DELETED
@@ -1,68 +0,0 @@
import torch
import torchvision
from torch import nn
from torchvision import transforms


def create_model(num_classes: int = 32,
                 seed: int = 42):
    """Creates a feature extractor model and transforms.

    Args:
        num_classes (int, optional): number of classes in the classifier head.
            Defaults to 32.
        seed (int, optional): random seed value. Defaults to 42.

    Returns:
        model (torch.nn.Module): vit feature extractor model.
        transforms (torchvision.transforms): vit image transforms.
    """
    IMG_SIZE = 28
    model_transforms = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor()])

    # Create a convolutional neural network
    class Model(nn.Module):
        def __init__(self, input_shape: int, hidden_units: int, output_shape: int):
            super().__init__()
            self.block_1 = nn.Sequential(
                nn.Conv2d(in_channels=input_shape,
                          out_channels=hidden_units,
                          kernel_size=3,  # how big is the square that's going over the image?
                          stride=1,  # default
                          padding=1),  # options = "valid" (no padding) or "same" (output has same shape as input) or int for specific number
                nn.ReLU(),
                nn.Conv2d(in_channels=hidden_units,
                          out_channels=hidden_units,
                          kernel_size=3,
                          stride=1,
                          padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2,
                             stride=2)  # default stride value is same as kernel_size
            )
            self.block_2 = nn.Sequential(
                nn.Conv2d(hidden_units, hidden_units, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(hidden_units, hidden_units, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2)
            )
            self.classifier = nn.Sequential(
                nn.Flatten(),
                nn.Linear(in_features=hidden_units*7*7,
                          out_features=output_shape)
            )

        def forward(self, x: torch.Tensor):
            # x = self.block_1(x)
            # print(x.shape)
            # x = self.block_2(x)
            # print(x.shape)
            # x = self.classifier(x)
            # print(x.shape)
            x = self.classifier(self.block_2(self.block_1(x)))
            return x
    return Model, model_transforms
model/2model28.h5
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9ec76d2434f8cb24f08df038ecb93ca91b15db9110df2ba340488c24c123b64b
size 14480208
solver.py
CHANGED
@@ -3,32 +3,25 @@ from PIL import ImageFont, ImageDraw, Image
 import cv2
 import numpy as np
 import os
-import model
 import src.solve as solve
 from typing import Tuple
 import pytesseract
 import re
 import shutil
-import torch
-import torch.nn as nn
-import torchvision
-from torchvision import transforms
-from model import create_model
 import tensorflow as tf
 from tensorflow import keras
+import random
 
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
 # Setup class names
-with open("class_names.txt", "r") as f:
-    class_names = [names.strip() for names in
+with open("class_names.txt", "r") as f:  # reading them in from class_names.txt
+    class_names = [names.strip() for names in f.readlines()]
 
-model1 = tf.keras.models.load_model('model/
-model2 = tf.keras.models.load_model('model/
-model3 = tf.keras.models.load_model('model/
+model1 = tf.keras.models.load_model('model/model30.h5')
+model2 = tf.keras.models.load_model('model/model15.h5')
+model3 = tf.keras.models.load_model('model/model2.h5')
 
+palabras_1 = []
 # Borrar el directorio de imagenes
 folder = 'output'
 for filename in os.listdir(folder):
@@ -52,18 +45,25 @@ for filename in os.listdir(folder):
         print('Failed to delete %s. Reason: %s' % (file_path, e))
 
 
-def get_words(img)
-    print(type(img))
+def get_words(img):
+    #print(type(img))
     # str to filepath
     img = Image.open(img)
     # Display image
-    #img.show()
+    # img.show()
     text = pytesseract.image_to_string(img, lang="spa+eng", config="--psm 11")
     text = text.upper()
     text = re.split('\W+', text)
     text.pop()
+    #palabras_1 = text
+    #print(palabras_1)
+    # array to string text
+    text = ' '.join(text)
+    # add comma to text
+    text = text.replace(" ", ",")
     return text
 
+
 def getmat(listaCuadrados, filas, columnas):
     matrix = [[0 for i in range(columnas)] for j in range(filas)]
     matrixT = [[0 for i in range(columnas)] for j in range(filas)]
@@ -94,40 +94,50 @@ def get_colums_and_rows(listaCuadrados):
         columnas = columnas + 1
     return filas, columnas
 
+
 def read_board(img, words):
+    #(type(img))
     # str to filepath
     img = Image.open(img)
     # Display image
    img.show()
-    #Print words
-    print("Palabras a buscar: ",
+    # Print words
+    #print("Palabras a buscar: ", palabras_1)
+
 
 def solve_puzzle(img, words):
-    #print(type(img))
+    # print(type(img))
     # str to filepath
+    #print(type(words))
+    #print(words)
+    # img = Image.open(img)
     # Pil to opencv compatible
-    pil_image = Image.open(img).convert('RGB')
-    open_cv_image = np.array(pil_image)
-    # Convert RGB to BGR
-    open_cv_image = open_cv_image[:, :, ::-1].copy()
+    pil_image = Image.open(img).convert('RGB')
+    open_cv_image = np.array(pil_image)
+    # Convert RGB to BGR
+    open_cv_image = open_cv_image[:, :, ::-1].copy()
     # Display image
-    #Print words
+    # Print words
     img = open_cv_image
+    # string to array
+    words = words.split(",")
+    # remove last ,
+    #print(words)
+
     imgc = img.copy()
     imgsol = img.copy()
-    imgc = cv2.cvtColor(imgc, cv2.
+    imgc = cv2.cvtColor(imgc, cv2.COLOR_RGB2GRAY)
     imgc = np.invert(imgc)
-    gray = cv2.cvtColor(img, cv2.
+    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+    blur = cv2.GaussianBlur(gray, (5, 5), 0)
     # save the blurred image
+    cv2.imwrite("output/blur.png", blur)
     # display blurred image
-    threshten = cv2.threshold(
+    threshten = cv2.threshold(
+        blur, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
     thresh = cv2.adaptiveThreshold(threshten, 255, 1, 1, 11, 2)
-    contours, hierarchy = cv2.findContours(
+    contours, hierarchy = cv2.findContours(
+        thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
     # Draw contours and save the image
 
     characters = np.array([
@@ -164,7 +174,6 @@ def solve_puzzle(img, words):
         cv2.rectangle(img, (x0, y0), (x1, y1), (0, 255, 0), 2)
         img2 = imgc[y0:y1, x0:x1]
         img2 = cv2.resize(img2, (28, 28))
-        img2 = cv2.resize(img2,(28,28))
         img_array = img2.reshape(1, 28, 28, 1)
         prediction1 = np.argmax(model1.predict(img_array))
         prediction2 = np.argmax(model2.predict(img_array))
@@ -175,11 +184,12 @@ def solve_puzzle(img, words):
         elif prediction1 == prediction2:
             pred = prediction1
         elif prediction2 == prediction3:
+            pred = prediction2
         elif prediction1 == prediction3:
+            pred = prediction3
         else:
             pred = 32
+        #print(characters[pred])
         contCuadrados["anchura"] = x
         contCuadrados["altura"] = y
         contCuadrados["centrox"] = (x + x + w)/2
@@ -191,20 +201,30 @@ def solve_puzzle(img, words):
         draw.text(((x0+x1)/2, y0-10),
                   characters[pred], font=fnt, fill=(255, 0, 0, 0))
         img = np.array(img_pil)
+    #cv2.imwrite("output/Tablero_Labels.png", img)
    filas, columnas = get_colums_and_rows(listaCuadrados)
+    # print listaCuadrados
+    # print(listaCuadrados)
    matrix, matrixT = getmat(listaCuadrados, filas, columnas)
    palabrasxy = []
 
+    # print()
+    # print("Palabras a buscar:")
+    # for i in words:
+    #     print(i)
+    # print()
+    # for i in range(filas):
+    #     for j in range(columnas):
+    #         print(matrix[i][j], end = " ")
+    #     print()
+    # print()
+    # print()
    image_new = imgsol.copy()
    overlay = imgsol.copy()
-    import random
    index = 0
    index2 = 0
    for i in words:
+        #(i)
        xy_positionsvec, find = solve.find_word(matrix, i)
        if find:
            palabrasxy.append(xy_positionsvec)
@@ -212,8 +232,8 @@ def solve_puzzle(img, words):
            # print(len(xy_positionsvec))
            xy = xy_positionsvec[0]
            xy2 = xy_positionsvec[len(xy_positionsvec)-1]
-            #print(xy["x"], " ",xy["y"])
-            #print(xy2["x"], " ",xy2["y"])
+            # print(xy["x"], " ",xy["y"])
+            # print(xy2["x"], " ",xy2["y"])
            coordreal = matrixT[xy["x"]][xy["y"]]
            coordreal2 = matrixT[xy2["x"]][xy2["y"]]
            centrox = coordreal["centrox"]
@@ -232,11 +252,15 @@ def solve_puzzle(img, words):
            cv2.line(overlay2, (centrox, centroy), (centrox2, centroy2), color,
                     thickness=int(abs(coordreal["altura"] - coordreal["centroy"])*2))
            image_word = cv2.addWeighted(overlay2, 0.4, image_new, 1 - 0.4, 0)
-            cv2.imwrite("
+            cv2.imwrite("wordsPuzzle/" + words[index2] + ".jpg", image_word)
+            # append the image into a numpy array
+
+            #print(words[index2])
 
            index += 1
            index2 += 1
    alpha = 0.4  # Transparency factor
    image_new = cv2.addWeighted(overlay, alpha, image_new, 1 - alpha, 0)
-
-    return
+    cv2.imwrite("output/Tablero_solucion.png", image_new)
+    # return the images in wordsPuzzle folder as numpy arrays
+    return image_new
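Note on the completed ensemble branches: solve_puzzle() scores every cropped 28x28 cell with three Keras models, and this commit fills in the two branches that assign pred when only a single pair of predictions agrees; class 32 remains the fallback when no pair agrees. The same 2-of-3 vote in isolation, as a sketch (the function name is illustrative, and the all-three-equal case sits in an earlier branch outside this hunk):

    def majority_vote(p1: int, p2: int, p3: int, unknown: int = 32) -> int:
        # Any pair that agrees wins; 32 is the "no agreement" fallback,
        # mirroring the if/elif chain in solve_puzzle().
        if p1 == p2 or p1 == p3:
            return p1
        if p2 == p3:
            return p2
        return unknown

The other behavioural change here is the word-list round trip: get_words() now joins the OCR tokens into a comma-separated string (so the user can edit the detected words in the Textbox), and solve_puzzle() splits that string back apart with words.split(",") before searching the board.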
src/__pycache__/solve.cpython-310.pyc
DELETED
Binary file (2.93 kB)
src/__pycache__/solve.cpython-39.pyc
DELETED
Binary file (2.96 kB)
src/__pycache__/tesseract.cpython-310.pyc
DELETED
Binary file (763 Bytes)
src/__pycache__/tesseract.cpython-39.pyc
DELETED
Binary file (940 Bytes)
src/solve.py
CHANGED
@@ -18,7 +18,7 @@ def find_word (wordsearch, word):
             # Word foundf
             return xy_positionsvec,True
     # Word not found
-    print(word, ' No encontrada')
+    #print(word, ' No encontrada')
     return xy_positionsvec,False
 
 def check_start (wordsearch, word, start_pos):
@@ -40,9 +40,9 @@ def check_dir (wordsearch, word, start_pos, dir):
     while (chars_match(found_chars, word)):
         if (len(found_chars) == len(word)):
             # If found all characters and all characters found are correct, then word has been found
-            print('')
-            print(word, ' Encontrada en:')
-            print('')
+            #print('')
+            #print(word, ' Encontrada en:')
+            #print('')
             # Draw wordsearch on command line. Display found characters and '-' everywhere else
             index =1
             for x in range(0, len(wordsearch)):
@@ -68,8 +68,8 @@ def check_dir (wordsearch, word, start_pos, dir):
                 else:
                     line = line + " -"
 
-                print(line)
-            print('')
+                #print(line)
+            #print('')
             return True, xy_positionsvec
         # Have not found enough letters so look at the next one
         current_pos = [current_pos[0] + dir[0], current_pos[1] + dir[1]]
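The hunks above only silence console output; the search logic itself is unchanged. For reference, the core idea behind check_dir() is a straight-line walk from a start cell along one of the grid directions, collecting a coordinate per matched letter. A compact paraphrase of that idea (not the module's own code), assuming a rectangular grid of single characters:

    def walk(grid, word, start, direction):
        # Step from start along direction; stop at the first out-of-bounds
        # cell or character mismatch, else return one (row, col) per letter.
        (r, c), (dr, dc) = start, direction
        positions = []
        for ch in word:
            if not (0 <= r < len(grid) and 0 <= c < len(grid[0])) or grid[r][c] != ch:
                return None
            positions.append((r, c))
            r, c = r + dr, c + dc
        return positions

For example, walk([["C", "A", "T"]], "CAT", (0, 0), (0, 1)) returns [(0, 0), (0, 1), (0, 2)].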
words/descarga.png
DELETED
Binary file (224 kB)
words/words1.png
DELETED
Binary file (73 kB)