napatswift committed on
Commit
6ad06e7
·
1 Parent(s): 741d744

Change model arch

Browse files
Files changed (3) hide show
  1. main.py +3 -5
  2. model/text-det/psenet.pth +3 -0
  3. model/text-det/psenet.py +326 -0
main.py CHANGED
@@ -7,11 +7,9 @@ import torch
7
 
8
  print('Loading model...')
9
  device = 'gpu' if torch.cuda.is_available() else 'cpu'
10
- table_det = init_detector('model/table-det/config.py',
11
- 'model/table-det/model.pth', device=device)
12
 
13
- ocr = MMOCRInferencer(det='model/text-det/config.py',
14
- det_weights='model/text-det/model.pth',
15
  device=device)
16
 
17
  def get_rec(points):
@@ -39,4 +37,4 @@ def run():
39
 
40
 
41
  if __name__ == "__main__":
42
- run()
 
7
 
8
  print('Loading model...')
9
  device = 'gpu' if torch.cuda.is_available() else 'cpu'
 
 
10
 
11
+ ocr = MMOCRInferencer(det='model/text-det/psenet.py',
12
+ det_weights='model/text-det/psenet.pth',
13
  device=device)
14
 
15
  def get_rec(points):
 
37
 
38
 
39
  if __name__ == "__main__":
40
+ run()
model/text-det/psenet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8575eddcbed1c0a1151817ef05bb2df11a27979ea1f4a61fde5bbecd0c3e2595
3
+ size 352447333
model/text-det/psenet.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ file_client_args = dict(backend='disk')
2
+ model = dict(
3
+ type='PSENet',
4
+ backbone=dict(
5
+ type='mmdet.ResNet',
6
+ depth=50,
7
+ num_stages=4,
8
+ out_indices=(0, 1, 2, 3),
9
+ frozen_stages=-1,
10
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
11
+ init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
12
+ norm_eval=True,
13
+ style='caffe'),
14
+ neck=dict(
15
+ type='FPNF',
16
+ in_channels=[256, 512, 1024, 2048],
17
+ out_channels=256,
18
+ fusion_type='concat'),
19
+ det_head=dict(
20
+ type='PSEHead',
21
+ in_channels=[256],
22
+ hidden_dim=256,
23
+ out_channel=7,
24
+ module_loss=dict(type='PSEModuleLoss'),
25
+ postprocessor=dict(type='PSEPostprocessor', text_repr_type='poly')),
26
+ data_preprocessor=dict(
27
+ type='TextDetDataPreprocessor',
28
+ mean=[123.675, 116.28, 103.53],
29
+ std=[58.395, 57.12, 57.375],
30
+ bgr_to_rgb=True,
31
+ pad_size_divisor=32))
32
+ train_pipeline = [
33
+ dict(
34
+ type='LoadImageFromFile',
35
+ file_client_args=dict(backend='disk'),
36
+ color_type='color_ignore_orientation'),
37
+ dict(
38
+ type='LoadOCRAnnotations',
39
+ with_polygon=True,
40
+ with_bbox=True,
41
+ with_label=True),
42
+ dict(
43
+ type='TorchVisionWrapper',
44
+ op='ColorJitter',
45
+ brightness=0.12549019607843137,
46
+ saturation=0.5),
47
+ dict(type='FixInvalidPolygon'),
48
+ dict(type='ShortScaleAspectJitter', short_size=736, scale_divisor=32),
49
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
50
+ dict(type='RandomRotate', max_angle=10),
51
+ dict(type='TextDetRandomCrop', target_size=(736, 736)),
52
+ dict(type='Pad', size=(736, 736)),
53
+ dict(
54
+ type='PackTextDetInputs',
55
+ meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
56
+ ]
57
+ test_pipeline = [
58
+ dict(
59
+ type='LoadImageFromFile',
60
+ file_client_args=dict(backend='disk'),
61
+ color_type='color_ignore_orientation'),
62
+ dict(type='Resize', scale=(2240, 2240), keep_ratio=True),
63
+ dict(
64
+ type='LoadOCRAnnotations',
65
+ with_polygon=True,
66
+ with_bbox=True,
67
+ with_label=True),
68
+ dict(
69
+ type='PackTextDetInputs',
70
+ meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
71
+ ]
72
+ thvc_textdet_data_root = 'data/det/vl+vc-textdet'
73
+ thvc_textdet_train = dict(
74
+ type='OCRDataset',
75
+ data_root='data/det/vl+vc-textdet',
76
+ ann_file='textdet_train.json',
77
+ data_prefix=dict(img_path='imgs/'),
78
+ filter_cfg=dict(filter_empty_gt=True, min_size=32),
79
+ pipeline=[
80
+ dict(
81
+ type='LoadImageFromFile',
82
+ file_client_args=dict(backend='disk'),
83
+ color_type='color_ignore_orientation'),
84
+ dict(
85
+ type='LoadOCRAnnotations',
86
+ with_polygon=True,
87
+ with_bbox=True,
88
+ with_label=True),
89
+ dict(
90
+ type='TorchVisionWrapper',
91
+ op='ColorJitter',
92
+ brightness=0.12549019607843137,
93
+ saturation=0.5),
94
+ dict(type='FixInvalidPolygon'),
95
+ dict(type='ShortScaleAspectJitter', short_size=736, scale_divisor=32),
96
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
97
+ dict(type='RandomRotate', max_angle=10),
98
+ dict(type='TextDetRandomCrop', target_size=(736, 736)),
99
+ dict(type='Pad', size=(736, 736)),
100
+ dict(
101
+ type='PackTextDetInputs',
102
+ meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
103
+ ])
104
+ thvc_textdet_test = dict(
105
+ type='OCRDataset',
106
+ data_root='data/det/vl+vc-textdet',
107
+ ann_file='textdet_test.json',
108
+ data_prefix=dict(img_path='imgs/'),
109
+ test_mode=True,
110
+ pipeline=None)
111
+ thvote_textdet_data_root = 'data/det/textdet-thvote'
112
+ thvote_textdet_train = dict(
113
+ type='OCRDataset',
114
+ data_root='data/det/textdet-thvote',
115
+ ann_file='textdet_train.json',
116
+ data_prefix=dict(img_path='imgs/'),
117
+ filter_cfg=dict(filter_empty_gt=True, min_size=32),
118
+ pipeline=None)
119
+ thvote_textdet_test = dict(
120
+ type='OCRDataset',
121
+ data_root='data/det/textdet-thvote',
122
+ ann_file='textdet_test.json',
123
+ data_prefix=dict(img_path='imgs/'),
124
+ test_mode=True,
125
+ pipeline=[
126
+ dict(
127
+ type='LoadImageFromFile',
128
+ file_client_args=dict(backend='disk'),
129
+ color_type='color_ignore_orientation'),
130
+ dict(type='Resize', scale=(2240, 2240), keep_ratio=True),
131
+ dict(
132
+ type='LoadOCRAnnotations',
133
+ with_polygon=True,
134
+ with_bbox=True,
135
+ with_label=True),
136
+ dict(
137
+ type='PackTextDetInputs',
138
+ meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
139
+ ])
140
+ default_scope = 'mmocr'
141
+ env_cfg = dict(
142
+ cudnn_benchmark=True,
143
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
144
+ dist_cfg=dict(backend='nccl'))
145
+ randomness = dict(seed=None)
146
+ default_hooks = dict(
147
+ timer=dict(type='IterTimerHook'),
148
+ logger=dict(type='LoggerHook', interval=100),
149
+ param_scheduler=dict(type='ParamSchedulerHook'),
150
+ checkpoint=dict(type='CheckpointHook', interval=10),
151
+ sampler_seed=dict(type='DistSamplerSeedHook'),
152
+ sync_buffer=dict(type='SyncBuffersHook'),
153
+ visualization=dict(
154
+ type='VisualizationHook',
155
+ interval=1,
156
+ enable=False,
157
+ show=False,
158
+ draw_gt=False,
159
+ draw_pred=False))
160
+ log_level = 'INFO'
161
+ log_processor = dict(type='LogProcessor', window_size=10, by_epoch=True)
162
+ load_from = None
163
+ resume = True
164
+ val_evaluator = dict(type='HmeanIOUMetric')
165
+ test_evaluator = dict(type='HmeanIOUMetric')
166
+ vis_backends = [dict(type='LocalVisBackend')]
167
+ visualizer = dict(
168
+ type='TextDetLocalVisualizer',
169
+ name='visualizer',
170
+ vis_backends=[dict(type='LocalVisBackend')])
171
+ max_epochs = 200
172
+ optim_wrapper = dict(
173
+ type='OptimWrapper', optimizer=dict(type='Adam', lr=0.001))
174
+ train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=30, val_interval=10)
175
+ val_cfg = dict(type='ValLoop')
176
+ test_cfg = dict(type='TestLoop')
177
+ param_scheduler = [dict(type='PolyLR', power=0.9, end=200)]
178
+ thvotecount_textdet_train = dict(
179
+ type='OCRDataset',
180
+ data_root='data/det/vl+vc-textdet',
181
+ ann_file='textdet_train.json',
182
+ data_prefix=dict(img_path='imgs/'),
183
+ filter_cfg=dict(filter_empty_gt=True, min_size=32),
184
+ pipeline=[
185
+ dict(
186
+ type='LoadImageFromFile',
187
+ file_client_args=dict(backend='disk'),
188
+ color_type='color_ignore_orientation'),
189
+ dict(
190
+ type='LoadOCRAnnotations',
191
+ with_polygon=True,
192
+ with_bbox=True,
193
+ with_label=True),
194
+ dict(
195
+ type='TorchVisionWrapper',
196
+ op='ColorJitter',
197
+ brightness=0.12549019607843137,
198
+ saturation=0.5),
199
+ dict(type='FixInvalidPolygon'),
200
+ dict(type='ShortScaleAspectJitter', short_size=736, scale_divisor=32),
201
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
202
+ dict(type='RandomRotate', max_angle=10),
203
+ dict(type='TextDetRandomCrop', target_size=(736, 736)),
204
+ dict(type='Pad', size=(736, 736)),
205
+ dict(
206
+ type='PackTextDetInputs',
207
+ meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
208
+ ])
209
+ thvotecount_textdet_test = dict(
210
+ type='OCRDataset',
211
+ data_root='data/det/textdet-thvote',
212
+ ann_file='textdet_test.json',
213
+ data_prefix=dict(img_path='imgs/'),
214
+ test_mode=True,
215
+ pipeline=[
216
+ dict(
217
+ type='LoadImageFromFile',
218
+ file_client_args=dict(backend='disk'),
219
+ color_type='color_ignore_orientation'),
220
+ dict(type='Resize', scale=(2240, 2240), keep_ratio=True),
221
+ dict(
222
+ type='LoadOCRAnnotations',
223
+ with_polygon=True,
224
+ with_bbox=True,
225
+ with_label=True),
226
+ dict(
227
+ type='PackTextDetInputs',
228
+ meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
229
+ ])
230
+ train_dataloader = dict(
231
+ batch_size=10,
232
+ num_workers=16,
233
+ persistent_workers=True,
234
+ sampler=dict(type='DefaultSampler', shuffle=True),
235
+ dataset=dict(
236
+ type='OCRDataset',
237
+ data_root='data/det/vl+vc-textdet',
238
+ ann_file='textdet_train.json',
239
+ data_prefix=dict(img_path='imgs/'),
240
+ filter_cfg=dict(filter_empty_gt=True, min_size=32),
241
+ pipeline=[
242
+ dict(
243
+ type='LoadImageFromFile',
244
+ file_client_args=dict(backend='disk'),
245
+ color_type='color_ignore_orientation'),
246
+ dict(
247
+ type='LoadOCRAnnotations',
248
+ with_polygon=True,
249
+ with_bbox=True,
250
+ with_label=True),
251
+ dict(
252
+ type='TorchVisionWrapper',
253
+ op='ColorJitter',
254
+ brightness=0.12549019607843137,
255
+ saturation=0.5),
256
+ dict(type='FixInvalidPolygon'),
257
+ dict(
258
+ type='ShortScaleAspectJitter',
259
+ short_size=736,
260
+ scale_divisor=32),
261
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
262
+ dict(type='RandomRotate', max_angle=10),
263
+ dict(type='TextDetRandomCrop', target_size=(736, 736)),
264
+ dict(type='Pad', size=(736, 736)),
265
+ dict(
266
+ type='PackTextDetInputs',
267
+ meta_keys=('img_path', 'ori_shape', 'img_shape',
268
+ 'scale_factor'))
269
+ ]))
270
+ val_dataloader = dict(
271
+ batch_size=4,
272
+ num_workers=4,
273
+ persistent_workers=True,
274
+ sampler=dict(type='DefaultSampler', shuffle=False),
275
+ dataset=dict(
276
+ type='OCRDataset',
277
+ data_root='data/det/textdet-thvote',
278
+ ann_file='textdet_test.json',
279
+ data_prefix=dict(img_path='imgs/'),
280
+ test_mode=True,
281
+ pipeline=[
282
+ dict(
283
+ type='LoadImageFromFile',
284
+ file_client_args=dict(backend='disk'),
285
+ color_type='color_ignore_orientation'),
286
+ dict(type='Resize', scale=(2240, 2240), keep_ratio=True),
287
+ dict(
288
+ type='LoadOCRAnnotations',
289
+ with_polygon=True,
290
+ with_bbox=True,
291
+ with_label=True),
292
+ dict(
293
+ type='PackTextDetInputs',
294
+ meta_keys=('img_path', 'ori_shape', 'img_shape',
295
+ 'scale_factor'))
296
+ ]))
297
+ test_dataloader = dict(
298
+ batch_size=4,
299
+ num_workers=4,
300
+ persistent_workers=True,
301
+ sampler=dict(type='DefaultSampler', shuffle=False),
302
+ dataset=dict(
303
+ type='OCRDataset',
304
+ data_root='data/det/textdet-thvote',
305
+ ann_file='textdet_test.json',
306
+ data_prefix=dict(img_path='imgs/'),
307
+ test_mode=True,
308
+ pipeline=[
309
+ dict(
310
+ type='LoadImageFromFile',
311
+ file_client_args=dict(backend='disk'),
312
+ color_type='color_ignore_orientation'),
313
+ dict(type='Resize', scale=(2240, 2240), keep_ratio=True),
314
+ dict(
315
+ type='LoadOCRAnnotations',
316
+ with_polygon=True,
317
+ with_bbox=True,
318
+ with_label=True),
319
+ dict(
320
+ type='PackTextDetInputs',
321
+ meta_keys=('img_path', 'ori_shape', 'img_shape',
322
+ 'scale_factor'))
323
+ ]))
324
+ auto_scale_lr = dict(base_batch_size=32)
325
+ launcher = 'none'
326
+ work_dir = './work_dirs/psenet_resnet50_fpnf_votecount'