LiuPengNGP commited on
Commit
10cdadf
·
verified ·
1 Parent(s): 915d33b

Upload 8 files

Browse files
model/___init__.py ADDED
File without changes
model/app_utils.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: app_utils.py
3
+ Author: Elena Ryumina and Dmitry Ryumin
4
+ Description: This module contains utility functions for facial expression recognition application.
5
+ License: MIT License
6
+ """
7
+
8
+ import torch
9
+ import numpy as np
10
+ import mediapipe as mp
11
+ from PIL import Image
12
+ import cv2
13
+ from pytorch_grad_cam.utils.image import show_cam_on_image
14
+
15
+ # Importing necessary components for the Gradio app
16
+ from app.model import pth_model_static, pth_model_dynamic, cam, pth_processing
17
+ from app.face_utils import get_box, display_info
18
+ from app.config import DICT_EMO, config_data
19
+ from app.plot import statistics_plot
20
+
21
+ mp_face_mesh = mp.solutions.face_mesh
22
+
23
+
24
+ def preprocess_image_and_predict(inp):
25
+ inp = np.array(inp)
26
+
27
+ if inp is None:
28
+ return None, None, None
29
+
30
+ try:
31
+ h, w = inp.shape[:2]
32
+ except Exception:
33
+ return None, None, None
34
+
35
+ with mp_face_mesh.FaceMesh(
36
+ max_num_faces=1,
37
+ refine_landmarks=False,
38
+ min_detection_confidence=0.5,
39
+ min_tracking_confidence=0.5,
40
+ ) as face_mesh:
41
+ results = face_mesh.process(inp)
42
+ if results.multi_face_landmarks:
43
+ for fl in results.multi_face_landmarks:
44
+ startX, startY, endX, endY = get_box(fl, w, h)
45
+ cur_face = inp[startY:endY, startX:endX]
46
+ cur_face_n = pth_processing(Image.fromarray(cur_face))
47
+ with torch.no_grad():
48
+ prediction = (
49
+ torch.nn.functional.softmax(pth_model_static(cur_face_n), dim=1)
50
+ .detach()
51
+ .numpy()[0]
52
+ )
53
+ confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}
54
+ grayscale_cam = cam(input_tensor=cur_face_n)
55
+ grayscale_cam = grayscale_cam[0, :]
56
+ cur_face_hm = cv2.resize(cur_face,(224,224))
57
+ cur_face_hm = np.float32(cur_face_hm) / 255
58
+ heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=True)
59
+
60
+ return cur_face, heatmap, confidences
61
+
62
+ else:
63
+ return None, None, None
64
+
65
+ def preprocess_video_and_predict(video):
66
+
67
+ cap = cv2.VideoCapture(video)
68
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
69
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
70
+ fps = np.round(cap.get(cv2.CAP_PROP_FPS))
71
+
72
+ path_save_video_face = 'result_face.mp4'
73
+ vid_writer_face = cv2.VideoWriter(path_save_video_face, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
74
+
75
+ path_save_video_hm = 'result_hm.mp4'
76
+ vid_writer_hm = cv2.VideoWriter(path_save_video_hm, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
77
+
78
+ lstm_features = []
79
+ count_frame = 1
80
+ count_face = 0
81
+ probs = []
82
+ frames = []
83
+ last_output = None
84
+ last_heatmap = None
85
+ cur_face = None
86
+
87
+ with mp_face_mesh.FaceMesh(
88
+ max_num_faces=1,
89
+ refine_landmarks=False,
90
+ min_detection_confidence=0.5,
91
+ min_tracking_confidence=0.5) as face_mesh:
92
+
93
+ while cap.isOpened():
94
+ _, frame = cap.read()
95
+ if frame is None: break
96
+
97
+ frame_copy = frame.copy()
98
+ frame_copy.flags.writeable = False
99
+ frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
100
+ results = face_mesh.process(frame_copy)
101
+ frame_copy.flags.writeable = True
102
+
103
+ if results.multi_face_landmarks:
104
+ for fl in results.multi_face_landmarks:
105
+ startX, startY, endX, endY = get_box(fl, w, h)
106
+ cur_face = frame_copy[startY:endY, startX: endX]
107
+
108
+ if count_face%config_data.FRAME_DOWNSAMPLING == 0:
109
+ cur_face_copy = pth_processing(Image.fromarray(cur_face))
110
+ with torch.no_grad():
111
+ features = torch.nn.functional.relu(pth_model_static.extract_features(cur_face_copy)).detach().numpy()
112
+
113
+ grayscale_cam = cam(input_tensor=cur_face_copy)
114
+ grayscale_cam = grayscale_cam[0, :]
115
+ cur_face_hm = cv2.resize(cur_face,(224,224), interpolation = cv2.INTER_AREA)
116
+ cur_face_hm = np.float32(cur_face_hm) / 255
117
+ heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=False)
118
+ last_heatmap = heatmap
119
+
120
+ if len(lstm_features) == 0:
121
+ lstm_features = [features]*10
122
+ else:
123
+ lstm_features = lstm_features[1:] + [features]
124
+
125
+ lstm_f = torch.from_numpy(np.vstack(lstm_features))
126
+ lstm_f = torch.unsqueeze(lstm_f, 0)
127
+ with torch.no_grad():
128
+ output = pth_model_dynamic(lstm_f).detach().numpy()
129
+ last_output = output
130
+
131
+ if count_face == 0:
132
+ count_face += 1
133
+
134
+ else:
135
+ if last_output is not None:
136
+ output = last_output
137
+ heatmap = last_heatmap
138
+
139
+ elif last_output is None:
140
+ output = np.empty((1, 7))
141
+ output[:] = np.nan
142
+
143
+ probs.append(output[0])
144
+ frames.append(count_frame)
145
+ else:
146
+ if last_output is not None:
147
+ lstm_features = []
148
+ empty = np.empty((7))
149
+ empty[:] = np.nan
150
+ probs.append(empty)
151
+ frames.append(count_frame)
152
+
153
+ if cur_face is not None:
154
+ heatmap_f = display_info(heatmap, 'Frame: {}'.format(count_frame), box_scale=.3)
155
+
156
+ cur_face = cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR)
157
+ cur_face = cv2.resize(cur_face, (224,224), interpolation = cv2.INTER_AREA)
158
+ cur_face = display_info(cur_face, 'Frame: {}'.format(count_frame), box_scale=.3)
159
+ vid_writer_face.write(cur_face)
160
+ vid_writer_hm.write(heatmap_f)
161
+
162
+ count_frame += 1
163
+ if count_face != 0:
164
+ count_face += 1
165
+
166
+ vid_writer_face.release()
167
+ vid_writer_hm.release()
168
+
169
+ stat = statistics_plot(frames, probs)
170
+
171
+ if not stat:
172
+ return None, None, None, None
173
+
174
+ return video, path_save_video_face, path_save_video_hm, stat
model/authors.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: authors.py
3
+ Author: Elena Ryumina and Dmitry Ryumin
4
+ Description: About the authors.
5
+ License: MIT License
6
+ """
7
+
8
+
9
+ AUTHORS = """
10
+ Authors: [Elena Ryumina](https://github.com/ElenaRyumina), [Dmitry Ryumin](https://github.com/DmitryRyumin), [Denis Dresvyanskiy](https://www.uni-ulm.de/en/nt/staff/research-assistants/dresvyanskiy/), [Maxim Markitantov](https://hci.nw.ru/en/employees/10) and [Alexey Karpov](https://hci.nw.ru/en/employees/1)
11
+
12
+ Authorship contribution:
13
+
14
+ App developers: ``Elena Ryumina`` and ``Dmitry Ryumin``
15
+
16
+ Methodology developers: ``Elena Ryumina``, ``Denis Dresvyanskiy`` and ``Alexey Karpov``
17
+
18
+ Model developer: ``Elena Ryumina``
19
+
20
+ TensorFlow to PyTorch model converters: ``Maxim Markitantov`` and ``Elena Ryumina``
21
+
22
+ Citation
23
+
24
+ If you are using EMO-AffectNetModel in your research, please consider to cite research [paper](https://www.sciencedirect.com/science/article/pii/S0925231222012656). Here is an example of BibTeX entry:
25
+
26
+ <div class="highlight highlight-text-bibtex notranslate position-relative overflow-auto" dir="auto"><pre><span class="pl-k">@article</span>{<span class="pl-en">RYUMINA2022</span>,
27
+ <span class="pl-s">title</span> = <span class="pl-s"><span class="pl-pds">{</span>In Search of a Robust Facial Expressions Recognition Model: A Large-Scale Visual Cross-Corpus Study<span class="pl-pds">}</span></span>,
28
+ <span class="pl-s">author</span> = <span class="pl-s"><span class="pl-pds">{</span>Elena Ryumina and Denis Dresvyanskiy and Alexey Karpov<span class="pl-pds">}</span></span>,
29
+ <span class="pl-s">journal</span> = <span class="pl-s"><span class="pl-pds">{</span>Neurocomputing<span class="pl-pds">}</span></span>,
30
+ <span class="pl-s">year</span> = <span class="pl-s"><span class="pl-pds">{</span>2022<span class="pl-pds">}</span></span>,
31
+ <span class="pl-s">doi</span> = <span class="pl-s"><span class="pl-pds">{</span>10.1016/j.neucom.2022.10.013<span class="pl-pds">}</span></span>,
32
+ <span class="pl-s">url</span> = <span class="pl-s"><span class="pl-pds">{</span>https://www.sciencedirect.com/science/article/pii/S0925231222012656<span class="pl-pds">}</span></span>,
33
+ }</div>
34
+ """
model/config.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: config.py
3
+ Author: Elena Ryumina and Dmitry Ryumin
4
+ Description: Configuration file.
5
+ License: MIT License
6
+ """
7
+
8
+ import toml
9
+ from typing import Dict
10
+ from types import SimpleNamespace
11
+
12
+
13
+ def flatten_dict(prefix: str, d: Dict) -> Dict:
14
+ result = {}
15
+
16
+ for k, v in d.items():
17
+ if isinstance(v, dict):
18
+ result.update(flatten_dict(f"{prefix}{k}_", v))
19
+ else:
20
+ result[f"{prefix}{k}"] = v
21
+
22
+ return result
23
+
24
+
25
+ config = toml.load("config.toml")
26
+
27
+ config_data = flatten_dict("", config)
28
+
29
+ config_data = SimpleNamespace(**config_data)
30
+
31
+ DICT_EMO = {
32
+ 0: "Neutral",
33
+ 1: "Happiness",
34
+ 2: "Sadness",
35
+ 3: "Surprise",
36
+ 4: "Fear",
37
+ 5: "Disgust",
38
+ 6: "Anger",
39
+ }
40
+
41
+ COLORS = {
42
+ 0: 'blue',
43
+ 1: 'orange',
44
+ 2: 'green',
45
+ 3: 'red',
46
+ 4: 'purple',
47
+ 5: 'brown',
48
+ 6: 'pink'
49
+ }
model/description.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: description.py
3
+ Author: Elena Ryumina and Dmitry Ryumin
4
+ Description: Project description for the Gradio app.
5
+ License: MIT License
6
+ """
7
+
8
+ # Importing necessary components for the Gradio app
9
+ from app.config import config_data
10
+
11
+ DESCRIPTION_STATIC = f"""\
12
+ # Static Facial Expression Recognition
13
+ <div class="app-flex-container">
14
+ <img src="https://img.shields.io/badge/version-v{config_data.APP_VERSION}-rc0" alt="Version">
15
+ <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FElenaRyumina%2FFacial_Expression_Recognition"><img src="https://api.visitorbadge.io/api/combined?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FElenaRyumina%2FFacial_Expression_Recognition&countColor=%23263759&style=flat" /></a>
16
+ <a href="https://paperswithcode.com/paper/in-search-of-a-robust-facial-expressions"><img src="https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/in-search-of-a-robust-facial-expressions/facial-expression-recognition-on-affectnet" /></a>
17
+ </div>
18
+ """
19
+
20
+ DESCRIPTION_DYNAMIC = f"""\
21
+ # Dynamic Facial Expression Recognition
22
+ <div class="app-flex-container">
23
+ <img src="https://img.shields.io/badge/version-v{config_data.APP_VERSION}-rc0" alt="Version">
24
+ <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FElenaRyumina%2FFacial_Expression_Recognition"><img src="https://api.visitorbadge.io/api/combined?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FElenaRyumina%2FFacial_Expression_Recognition&countColor=%23263759&style=flat" /></a>
25
+ <a href="https://paperswithcode.com/paper/in-search-of-a-robust-facial-expressions"><img src="https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/in-search-of-a-robust-facial-expressions/facial-expression-recognition-on-affectnet" /></a>
26
+ </div>
27
+ """
model/face_utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: face_utils.py
3
+ Author: Elena Ryumina and Dmitry Ryumin
4
+ Description: This module contains utility functions related to facial landmarks and image processing.
5
+ License: MIT License
6
+ """
7
+
8
+ import numpy as np
9
+ import math
10
+ import cv2
11
+
12
+
13
+ def norm_coordinates(normalized_x, normalized_y, image_width, image_height):
14
+ x_px = min(math.floor(normalized_x * image_width), image_width - 1)
15
+ y_px = min(math.floor(normalized_y * image_height), image_height - 1)
16
+ return x_px, y_px
17
+
18
+
19
+ def get_box(fl, w, h):
20
+ idx_to_coors = {}
21
+ for idx, landmark in enumerate(fl.landmark):
22
+ landmark_px = norm_coordinates(landmark.x, landmark.y, w, h)
23
+ if landmark_px:
24
+ idx_to_coors[idx] = landmark_px
25
+
26
+ x_min = np.min(np.asarray(list(idx_to_coors.values()))[:, 0])
27
+ y_min = np.min(np.asarray(list(idx_to_coors.values()))[:, 1])
28
+ endX = np.max(np.asarray(list(idx_to_coors.values()))[:, 0])
29
+ endY = np.max(np.asarray(list(idx_to_coors.values()))[:, 1])
30
+
31
+ (startX, startY) = (max(0, x_min), max(0, y_min))
32
+ (endX, endY) = (min(w - 1, endX), min(h - 1, endY))
33
+
34
+ return startX, startY, endX, endY
35
+
36
+ def display_info(img, text, margin=1.0, box_scale=1.0):
37
+ img_copy = img.copy()
38
+ img_h, img_w, _ = img_copy.shape
39
+ line_width = int(min(img_h, img_w) * 0.001)
40
+ thickness = max(int(line_width / 3), 1)
41
+
42
+ font_face = cv2.FONT_HERSHEY_SIMPLEX
43
+ font_color = (0, 0, 0)
44
+ font_scale = thickness / 1.5
45
+
46
+ t_w, t_h = cv2.getTextSize(text, font_face, font_scale, None)[0]
47
+
48
+ margin_n = int(t_h * margin)
49
+ sub_img = img_copy[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),
50
+ img_w - t_w - margin_n - int(2 * t_h * box_scale): img_w - margin_n]
51
+
52
+ white_rect = np.ones(sub_img.shape, dtype=np.uint8) * 255
53
+
54
+ img_copy[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),
55
+ img_w - t_w - margin_n - int(2 * t_h * box_scale):img_w - margin_n] = cv2.addWeighted(sub_img, 0.5, white_rect, .5, 1.0)
56
+
57
+ cv2.putText(img=img_copy,
58
+ text=text,
59
+ org=(img_w - t_w - margin_n - int(2 * t_h * box_scale) // 2,
60
+ 0 + margin_n + t_h + int(2 * t_h * box_scale) // 2),
61
+ fontFace=font_face,
62
+ fontScale=font_scale,
63
+ color=font_color,
64
+ thickness=thickness,
65
+ lineType=cv2.LINE_AA,
66
+ bottomLeftOrigin=False)
67
+
68
+ return img_copy
model/model.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: model.py
3
+ Author: Elena Ryumina and Dmitry Ryumin
4
+ Description: This module provides functions for loading and processing a pre-trained deep learning model
5
+ for facial expression recognition.
6
+ License: MIT License
7
+ """
8
+
9
+ import torch
10
+ import requests
11
+ from PIL import Image
12
+ from torchvision import transforms
13
+ from pytorch_grad_cam import GradCAM
14
+
15
+ # Importing necessary components for the Gradio app
16
+ from app.config import config_data
17
+ from app.model_architectures import ResNet50, LSTMPyTorch
18
+
19
+
20
+ def load_model(model_url, model_path):
21
+ try:
22
+ with requests.get(model_url, stream=True) as response:
23
+ with open(model_path, "wb") as file:
24
+ for chunk in response.iter_content(chunk_size=8192):
25
+ file.write(chunk)
26
+ return model_path
27
+ except Exception as e:
28
+ print(f"Error loading model: {e}")
29
+ return None
30
+
31
+ path_static = load_model(config_data.model_static_url, config_data.model_static_path)
32
+ pth_model_static = ResNet50(7, channels=3)
33
+ pth_model_static.load_state_dict(torch.load(path_static))
34
+ pth_model_static.eval()
35
+
36
+ path_dynamic = load_model(config_data.model_dynamic_url, config_data.model_dynamic_path)
37
+ pth_model_dynamic = LSTMPyTorch()
38
+ pth_model_dynamic.load_state_dict(torch.load(path_dynamic))
39
+ pth_model_dynamic.eval()
40
+
41
+ target_layers = [pth_model_static.layer4]
42
+ cam = GradCAM(model=pth_model_static, target_layers=target_layers)
43
+
44
+ def pth_processing(fp):
45
+ class PreprocessInput(torch.nn.Module):
46
+ def init(self):
47
+ super(PreprocessInput, self).init()
48
+
49
+ def forward(self, x):
50
+ x = x.to(torch.float32)
51
+ x = torch.flip(x, dims=(0,))
52
+ x[0, :, :] -= 91.4953
53
+ x[1, :, :] -= 103.8827
54
+ x[2, :, :] -= 131.0912
55
+ return x
56
+
57
+ def get_img_torch(img, target_size=(224, 224)):
58
+ transform = transforms.Compose([transforms.PILToTensor(), PreprocessInput()])
59
+ img = img.resize(target_size, Image.Resampling.NEAREST)
60
+ img = transform(img)
61
+ img = torch.unsqueeze(img, 0)
62
+ return img
63
+
64
+ return get_img_torch(fp)
model/model_architectures.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: model.py
3
+ Author: Elena Ryumina and Dmitry Ryumin
4
+ Description: This module provides model architectures.
5
+ License: MIT License
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import math
12
+
13
+ class Bottleneck(nn.Module):
14
+ expansion = 4
15
+ def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
16
+ super(Bottleneck, self).__init__()
17
+
18
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False)
19
+ self.batch_norm1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99)
20
+
21
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same', bias=False)
22
+ self.batch_norm2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99)
23
+
24
+ self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0, bias=False)
25
+ self.batch_norm3 = nn.BatchNorm2d(out_channels*self.expansion, eps=0.001, momentum=0.99)
26
+
27
+ self.i_downsample = i_downsample
28
+ self.stride = stride
29
+ self.relu = nn.ReLU()
30
+
31
+ def forward(self, x):
32
+ identity = x.clone()
33
+ x = self.relu(self.batch_norm1(self.conv1(x)))
34
+
35
+ x = self.relu(self.batch_norm2(self.conv2(x)))
36
+
37
+ x = self.conv3(x)
38
+ x = self.batch_norm3(x)
39
+
40
+ #downsample if needed
41
+ if self.i_downsample is not None:
42
+ identity = self.i_downsample(identity)
43
+ #add identity
44
+ x+=identity
45
+ x=self.relu(x)
46
+
47
+ return x
48
+
49
+ class Conv2dSame(torch.nn.Conv2d):
50
+
51
+ def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
52
+ return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
53
+
54
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
55
+ ih, iw = x.size()[-2:]
56
+
57
+ pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
58
+ pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])
59
+
60
+ if pad_h > 0 or pad_w > 0:
61
+ x = F.pad(
62
+ x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
63
+ )
64
+ return F.conv2d(
65
+ x,
66
+ self.weight,
67
+ self.bias,
68
+ self.stride,
69
+ self.padding,
70
+ self.dilation,
71
+ self.groups,
72
+ )
73
+
74
+ class ResNet(nn.Module):
75
+ def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
76
+ super(ResNet, self).__init__()
77
+ self.in_channels = 64
78
+
79
+ self.conv_layer_s2_same = Conv2dSame(num_channels, 64, 7, stride=2, groups=1, bias=False)
80
+ self.batch_norm1 = nn.BatchNorm2d(64, eps=0.001, momentum=0.99)
81
+ self.relu = nn.ReLU()
82
+ self.max_pool = nn.MaxPool2d(kernel_size = 3, stride=2)
83
+
84
+ self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64, stride=1)
85
+ self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)
86
+ self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)
87
+ self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)
88
+
89
+ self.avgpool = nn.AdaptiveAvgPool2d((1,1))
90
+ self.fc1 = nn.Linear(512*ResBlock.expansion, 512)
91
+ self.relu1 = nn.ReLU()
92
+ self.fc2 = nn.Linear(512, num_classes)
93
+
94
+ def extract_features(self, x):
95
+ x = self.relu(self.batch_norm1(self.conv_layer_s2_same(x)))
96
+ x = self.max_pool(x)
97
+ # print(x.shape)
98
+ x = self.layer1(x)
99
+ x = self.layer2(x)
100
+ x = self.layer3(x)
101
+ x = self.layer4(x)
102
+
103
+ x = self.avgpool(x)
104
+ x = x.reshape(x.shape[0], -1)
105
+ x = self.fc1(x)
106
+ return x
107
+
108
+ def forward(self, x):
109
+ x = self.extract_features(x)
110
+ x = self.relu1(x)
111
+ x = self.fc2(x)
112
+ return x
113
+
114
+ def _make_layer(self, ResBlock, blocks, planes, stride=1):
115
+ ii_downsample = None
116
+ layers = []
117
+
118
+ if stride != 1 or self.in_channels != planes*ResBlock.expansion:
119
+ ii_downsample = nn.Sequential(
120
+ nn.Conv2d(self.in_channels, planes*ResBlock.expansion, kernel_size=1, stride=stride, bias=False, padding=0),
121
+ nn.BatchNorm2d(planes*ResBlock.expansion, eps=0.001, momentum=0.99)
122
+ )
123
+
124
+ layers.append(ResBlock(self.in_channels, planes, i_downsample=ii_downsample, stride=stride))
125
+ self.in_channels = planes*ResBlock.expansion
126
+
127
+ for i in range(blocks-1):
128
+ layers.append(ResBlock(self.in_channels, planes))
129
+
130
+ return nn.Sequential(*layers)
131
+
132
+ def ResNet50(num_classes, channels=3):
133
+ return ResNet(Bottleneck, [3,4,6,3], num_classes, channels)
134
+
135
+
136
+ class LSTMPyTorch(nn.Module):
137
+ def __init__(self):
138
+ super(LSTMPyTorch, self).__init__()
139
+
140
+ self.lstm1 = nn.LSTM(input_size=512, hidden_size=512, batch_first=True, bidirectional=False)
141
+ self.lstm2 = nn.LSTM(input_size=512, hidden_size=256, batch_first=True, bidirectional=False)
142
+ self.fc = nn.Linear(256, 7)
143
+ self.softmax = nn.Softmax(dim=1)
144
+
145
+ def forward(self, x):
146
+ x, _ = self.lstm1(x)
147
+ x, _ = self.lstm2(x)
148
+ x = self.fc(x[:, -1, :])
149
+ x = self.softmax(x)
150
+ return x