kavg committed
Commit a228fac · 0 Parent(s)

Initial commit

Files changed (10)
  1. .gitignore +5 -0
  2. README.md +6 -0
  3. config.py +9 -0
  4. download_model.ipynb +144 -0
  5. main.py +66 -0
  6. models.py +236 -0
  7. ocr.py +89 -0
  8. preprocess.py +111 -0
  9. requirements.txt +0 -0
  10. token_classification.py +36 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ lilt-env/
+ .env
+ temp/
+ __pycache__/
+ models/
README.md ADDED
@@ -0,0 +1,6 @@
+ 1. Create a virtualenv
+ `virtualenv lilt-env`
+ 2. Install packages
+ `pip install -r requirements.txt`
+ 3. Run the app
+ `uvicorn main:app --reload`
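Once the server is up, the single endpoint (`/submit-doc`, defined in main.py) accepts a document image upload. A minimal client sketch — it assumes the `requests` package is installed, uvicorn's default address, and a hypothetical `sample_invoice.png` test image:

```python
import requests

# sample_invoice.png is a placeholder; any scanned form/document image should work
with open("sample_invoice.png", "rb") as f:
    resp = requests.post("http://127.0.0.1:8000/submit-doc", files={"file": f})

payload = resp.json()
# fields returned by main.py; each value is itself a JSON-encoded string
print(payload.keys())  # pred_relations, entities, input_ids, bboxes, token_labels
```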
config.py ADDED
@@ -0,0 +1,9 @@
+ from pydantic_settings import BaseSettings, SettingsConfigDict
+ from pydantic import Field
+
+ class Settings(BaseSettings):
+     model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8')
+     GCV_AUTH: dict
+     SER_MODEL: str
+     TOKENIZER: str
+     RE_MODEL: str
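Settings are read from a `.env` file in the project root. A sketch of what it might contain — the variable names come from the `Settings` class above, the values are placeholders, the model paths match where `download_model.ipynb` saves the weights, and `GCV_AUTH` holds the Google Cloud service-account JSON on a single line:

```
GCV_AUTH={"type": "service_account", "project_id": "<your-project>", "private_key": "...", "client_email": "..."}
SER_MODEL=models/lilt-ser-iob
TOKENIZER=models/lilt-tokenizer
RE_MODEL=models/lilt-re
```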
download_model.ipynb ADDED
@@ -0,0 +1,144 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "d:\\FYP\\lilt-app-without-fd\\lilt-env\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+       " from .autonotebook import tqdm as notebook_tqdm\n"
+      ]
+     }
+    ],
+    "source": [
+     "from transformers import LiltModel, AutoTokenizer, LiltForTokenClassification"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Download tokenizer"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "('models/lilt-tokenizer\\\\tokenizer_config.json',\n",
+        " 'models/lilt-tokenizer\\\\special_tokens_map.json',\n",
+        " 'models/lilt-tokenizer\\\\tokenizer.json')"
+       ]
+      },
+      "execution_count": 2,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "TOKENIZER = 'nielsr/lilt-xlm-roberta-base'\n",
+     "tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)\n",
+     "save_dir = 'models/lilt-tokenizer'\n",
+     "tokenizer.save_pretrained(save_dir)"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Download and save token classification model"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 4,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# download the model\n",
+     "MODEL = \"pierreguillou/lilt-xlm-roberta-base-finetuned-funsd-iob-original\"\n",
+     "model = LiltForTokenClassification.from_pretrained(MODEL)\n",
+     "\n",
+     "# save the model\n",
+     "save_dir = \"models/lilt-ser-iob\"\n",
+     "model.save_pretrained(save_dir)"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Download and save RE model"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 5,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "Downloading config.json: 100%|██████████| 794/794 [00:00<00:00, 61.2kB/s]\n",
+       "d:\\FYP\\lilt-app-without-fd\\lilt-env\\lib\\site-packages\\huggingface_hub\\file_download.py:133: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\Gihantha Kavishka\\.cache\\huggingface\\hub. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
+       "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
+       " warnings.warn(message)\n",
+       "Downloading pytorch_model.bin: 100%|██████████| 1.15G/1.15G [08:10<00:00, 2.34MB/s]\n",
+       "Some weights of the model checkpoint at kavg/layoutxlm-finetuned-xfund-fr-re were not used when initializing LiltModel: ['extractor.rel_classifier.linear.weight', 'extractor.entity_emb.weight', 'extractor.ffnn_tail.0.weight', 'extractor.ffnn_tail.3.bias', 'extractor.ffnn_head.3.weight', 'extractor.ffnn_head.0.weight', 'extractor.ffnn_tail.0.bias', 'extractor.ffnn_head.3.bias', 'extractor.rel_classifier.bilinear.weight', 'extractor.rel_classifier.linear.bias', 'extractor.ffnn_head.0.bias', 'extractor.ffnn_tail.3.weight']\n",
+       "- This IS expected if you are initializing LiltModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+       "- This IS NOT expected if you are initializing LiltModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+       "Some weights of LiltModel were not initialized from the model checkpoint at kavg/layoutxlm-finetuned-xfund-fr-re and are newly initialized: ['lilt.pooler.dense.bias', 'lilt.pooler.dense.weight']\n",
+       "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+      ]
+     }
+    ],
+    "source": [
+     "# download the model\n",
+     "MODEL = 'kavg/layoutxlm-finetuned-xfund-fr-re'\n",
+     "model = LiltModel.from_pretrained(MODEL)\n",
+     "\n",
+     "# save the model\n",
+     "save_dir = \"models/lilt-re\"\n",
+     "model.save_pretrained(save_dir)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "lilt-env",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.7.8"
+   },
+   "orig_nbformat": 4
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
main.py ADDED
@@ -0,0 +1,66 @@
+ from config import Settings
+ from preprocess import Preprocessor
+ import ocr
+ from PIL import Image
+ from transformers import LiltForTokenClassification
+ import token_classification
+ import torch
+ from fastapi import FastAPI, UploadFile
+ from contextlib import asynccontextmanager
+ import json
+ import io
+ from models import LiLTRobertaLikeForRelationExtraction
+ config = {}
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     # load settings, the OCR client and both LiLT models once at startup
+     settings = Settings()
+     config['settings'] = settings
+     config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     config['vision_client'] = ocr.VisionClient(settings.GCV_AUTH)
+     config['processor'] = Preprocessor(settings.TOKENIZER)
+     config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)
+     config['re_model'] = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
+     yield
+     # Clean up and release the resources
+     config.clear()
+
+ app = FastAPI(lifespan=lifespan)
+
+ @app.post("/submit-doc")
+ async def ProcessDocument(file: UploadFile):
+     tokenClassificationOutput = await LabelTokens(file)
+     reOutput = ExtractRelations(tokenClassificationOutput)
+     return reOutput
+
+ async def LabelTokens(file):
+     content = await file.read()
+     image = Image.open(io.BytesIO(content))
+     ocr_df = config['vision_client'].ocr(content, image)
+     input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image=image)
+     token_labels = token_classification.classifyTokens(config['ser_model'], input_ids, attention_mask, bbox, offset_mapping)
+     return {"token_labels": token_labels, "input_ids": input_ids, "bbox": bbox, "offset_mapping": offset_mapping, "attention_mask": attention_mask}
+
+ def ExtractRelations(tokenClassificationOutput):
+     token_labels = tokenClassificationOutput['token_labels']
+     input_ids = tokenClassificationOutput['input_ids']
+     offset_mapping = tokenClassificationOutput["offset_mapping"]
+     attention_mask = tokenClassificationOutput["attention_mask"]
+     bbox = tokenClassificationOutput["bbox"]
+
+     entities = token_classification.createEntities(config['ser_model'], token_labels, input_ids, offset_mapping)
+
+     # move the RE model and its inputs to the same device before inference
+     device = config['device']
+     config['re_model'].to(device)
+     input_ids = input_ids.to(device)
+     attention_mask = attention_mask.to(device)
+     bbox = bbox.to(device)
+
+     entity_dict = {'start': [entity[0] for entity in entities], 'end': [entity[1] for entity in entities], 'label': [entity[3] for entity in entities]}
+     relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]
+     with torch.no_grad():
+         outputs = config['re_model'](input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, entities=[entity_dict], relations=relations)
+
+     return {"pred_relations": json.dumps(outputs.pred_relations[0]), "entities": json.dumps(entities), "input_ids": json.dumps(input_ids.tolist()), "bboxes": json.dumps(bbox.tolist()), "token_labels": json.dumps(token_labels)}
models.py ADDED
@@ -0,0 +1,236 @@
+ from transformers import LiltPreTrainedModel, LiltModel
+ import copy
+ import torch
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ from dataclasses import dataclass
+ from typing import Dict, Optional, Tuple
+ from transformers.utils import ModelOutput
+
+ class BiaffineAttention(torch.nn.Module):
+     """Implements a biaffine attention operator for binary relation classification.
+
+     PyTorch implementation of the biaffine attention operator from "End-to-end neural relation
+     extraction using deep biaffine attention" (https://arxiv.org/abs/1812.11275) which can be used
+     as a classifier for binary relation classification.
+
+     Args:
+         in_features (int): The size of the feature dimension of the inputs.
+         out_features (int): The size of the feature dimension of the output.
+
+     Shape:
+         - x_1: `(N, *, in_features)` where `N` is the batch dimension and `*` means any number of
+           additional dimensions.
+         - x_2: `(N, *, in_features)`, where `N` is the batch dimension and `*` means any number of
+           additional dimensions.
+         - Output: `(N, *, out_features)`, where `N` is the batch dimension and `*` means any number
+           of additional dimensions.
+
+     Examples:
+         >>> batch_size, in_features, out_features = 32, 100, 4
+         >>> biaffine_attention = BiaffineAttention(in_features, out_features)
+         >>> x_1 = torch.randn(batch_size, in_features)
+         >>> x_2 = torch.randn(batch_size, in_features)
+         >>> output = biaffine_attention(x_1, x_2)
+         >>> print(output.size())
+         torch.Size([32, 4])
+     """
+
+     def __init__(self, in_features, out_features):
+         super(BiaffineAttention, self).__init__()
+
+         self.in_features = in_features
+         self.out_features = out_features
+
+         self.bilinear = torch.nn.Bilinear(in_features, in_features, out_features, bias=False)
+         self.linear = torch.nn.Linear(2 * in_features, out_features, bias=True)
+
+         self.reset_parameters()
+
+     def forward(self, x_1, x_2):
+         return self.bilinear(x_1, x_2) + self.linear(torch.cat((x_1, x_2), dim=-1))
+
+     def reset_parameters(self):
+         self.bilinear.reset_parameters()
+         self.linear.reset_parameters()
+
+
+ class REDecoder(nn.Module):
+     def __init__(self, config, input_size):
+         super().__init__()
+         self.entity_emb = nn.Embedding(3, input_size, scale_grad_by_freq=True)
+         projection = nn.Sequential(
+             nn.Linear(input_size * 2, config.hidden_size),
+             nn.ReLU(),
+             nn.Dropout(config.hidden_dropout_prob),
+             nn.Linear(config.hidden_size, config.hidden_size // 2),
+             nn.ReLU(),
+             nn.Dropout(config.hidden_dropout_prob),
+         )
+         self.ffnn_head = copy.deepcopy(projection)
+         self.ffnn_tail = copy.deepcopy(projection)
+         self.rel_classifier = BiaffineAttention(config.hidden_size // 2, 2)
+         self.loss_fct = CrossEntropyLoss()
+
+     def build_relation(self, relations, entities):
+         batch_size = len(relations)
+         new_relations = []
+         for b in range(batch_size):
+             if len(entities[b]["start"]) <= 2:
+                 entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]}
+             all_possible_relations = set(
+                 [
+                     (i, j)
+                     for i in range(len(entities[b]["label"]))
+                     for j in range(len(entities[b]["label"]))
+                     if entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2
+                 ]
+             )
+             if len(all_possible_relations) == 0:
+                 all_possible_relations = set([(0, 1)])
+             positive_relations = set(list(zip(relations[b]["head"], relations[b]["tail"])))
+             negative_relations = all_possible_relations - positive_relations
+             positive_relations = set([i for i in positive_relations if i in all_possible_relations])
+             reordered_relations = list(positive_relations) + list(negative_relations)
+             relation_per_doc = {"head": [], "tail": [], "label": []}
+             relation_per_doc["head"] = [i[0] for i in reordered_relations]
+             relation_per_doc["tail"] = [i[1] for i in reordered_relations]
+             relation_per_doc["label"] = [1] * len(positive_relations) + [0] * (
+                 len(reordered_relations) - len(positive_relations)
+             )
+             assert len(relation_per_doc["head"]) != 0
+             new_relations.append(relation_per_doc)
+         return new_relations, entities
+
+     def get_predicted_relations(self, logits, relations, entities):
+         pred_relations = []
+         for i, pred_label in enumerate(logits.argmax(-1)):
+             if pred_label != 1:
+                 continue
+             rel = {}
+             rel["head_id"] = relations["head"][i]
+             rel["head"] = (entities["start"][rel["head_id"]], entities["end"][rel["head_id"]])
+             rel["head_type"] = entities["label"][rel["head_id"]]
+
+             rel["tail_id"] = relations["tail"][i]
+             rel["tail"] = (entities["start"][rel["tail_id"]], entities["end"][rel["tail_id"]])
+             rel["tail_type"] = entities["label"][rel["tail_id"]]
+             rel["type"] = 1
+             pred_relations.append(rel)
+         return pred_relations
+
+     def forward(self, hidden_states, entities, relations):
+         batch_size, max_n_words, context_dim = hidden_states.size()
+         device = hidden_states.device
+         relations, entities = self.build_relation(relations, entities)
+         loss = 0
+         all_pred_relations = []
+         all_logits = []
+         all_labels = []
+
+         for b in range(batch_size):
+             head_entities = torch.tensor(relations[b]["head"], device=device)
+             tail_entities = torch.tensor(relations[b]["tail"], device=device)
+             relation_labels = torch.tensor(relations[b]["label"], device=device)
+             entities_start_index = torch.tensor(entities[b]["start"], device=device)
+             entities_labels = torch.tensor(entities[b]["label"], device=device)
+             head_index = entities_start_index[head_entities]
+             head_label = entities_labels[head_entities]
+             head_label_repr = self.entity_emb(head_label)
+
+             tail_index = entities_start_index[tail_entities]
+             tail_label = entities_labels[tail_entities]
+             tail_label_repr = self.entity_emb(tail_label)
+
+             head_repr = torch.cat(
+                 (hidden_states[b][head_index], head_label_repr),
+                 dim=-1,
+             )
+             tail_repr = torch.cat(
+                 (hidden_states[b][tail_index], tail_label_repr),
+                 dim=-1,
+             )
+             heads = self.ffnn_head(head_repr)
+             tails = self.ffnn_tail(tail_repr)
+             logits = self.rel_classifier(heads, tails)
+             pred_relations = self.get_predicted_relations(logits, relations[b], entities[b])
+             all_pred_relations.append(pred_relations)
+             all_logits.append(logits)
+             all_labels.append(relation_labels)
+         all_logits = torch.cat(all_logits, 0)
+         all_labels = torch.cat(all_labels, 0)
+         loss = self.loss_fct(all_logits, all_labels)
+         return loss, all_pred_relations
+
+
+ @dataclass
+ class ReOutput(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     logits: torch.FloatTensor = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+     entities: Optional[Dict] = None
+     relations: Optional[Dict] = None
+     pred_relations: Optional[Dict] = None
+
+ class REHead(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.extractor = REDecoder(config, config.hidden_size)
+
+     def forward(self, sequence_output, entities, relations):
+         sequence_output = self.dropout(sequence_output)
+         loss, pred_relations = self.extractor(sequence_output, entities, relations)
+         return ReOutput(
+             loss=loss,
+             entities=entities,
+             relations=relations,
+             pred_relations=pred_relations,
+         )
+
+ class LiLTRobertaLikeForRelationExtraction(LiltPreTrainedModel):
+     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+     _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.lilt = LiltModel(config, add_pooling_layer=False)
+         self.rehead = REHead(config)
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids=None,
+         bbox=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         entities=None,
+         relations=None,
+     ):
+
+         outputs = self.lilt(
+             input_ids,
+             bbox=bbox,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+
+         re_output = self.rehead(sequence_output, entities, relations)
+         return re_output
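For intuition on what `REDecoder.build_relation` does: entity label 1 marks questions and 2 marks answers, every question index paired with every answer index becomes a candidate relation, and candidates present in the gold relations keep label 1 while the rest get label 0. A standalone sketch of that pairing logic with toy data (not the classes above):

```python
# toy document: entity 0 is a question (1), entities 1 and 2 are answers (2)
entities = {"start": [4, 9, 15], "end": [6, 12, 17], "label": [1, 2, 2]}
gold = {"head": [0], "tail": [1]}  # entity 0 is linked to entity 1

# same comprehension used in build_relation: all question -> answer pairs
candidates = {
    (i, j)
    for i in range(len(entities["label"]))
    for j in range(len(entities["label"]))
    if entities["label"][i] == 1 and entities["label"][j] == 2
}
positive = set(zip(gold["head"], gold["tail"])) & candidates
negative = candidates - positive
print(sorted(candidates))              # [(0, 1), (0, 2)]
print(sorted(positive), sorted(negative))  # [(0, 1)] [(0, 2)]
```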
ocr.py ADDED
@@ -0,0 +1,89 @@
+ from google.cloud import vision
+ from google.oauth2 import service_account
+ from google.protobuf.json_format import MessageToJson
+ import pandas as pd
+ import json
+ import numpy as np
+ from PIL import Image
+ import io
+
+ image_ext = ("*.jpg", "*.jpeg", "*.png")
+
+ class VisionClient:
+     def __init__(self, auth):
+         credentials = service_account.Credentials.from_service_account_info(
+             auth
+         )
+         self.client = vision.ImageAnnotatorClient(credentials=credentials)
+
+     def send_request(self, image):
+         try:
+             image = vision.Image(content=image)
+         except ValueError as e:
+             print("Image could not be read")
+             return
+         response = self.client.document_text_detection(image, timeout=10)
+         return response
+
+     def get_response(self, content):
+         resp_js = None
+         try:
+             resp_js = self.send_request(content)
+         except Exception as e:
+             print("OCR request failed. Reason : {}".format(e))
+
+         return resp_js
+
+     def post_process(self, resp_js):
+         boxObjects = []
+         for i in range(1, len(resp_js.text_annotations)):
+             # Vision sometimes reverses the left and right coordinates, which would give a
+             # negative width, so take the smaller x as the left edge and the larger as the right.
+             obj = resp_js
+             if obj.text_annotations[i].bounding_poly.vertices[1].x > obj.text_annotations[i].bounding_poly.vertices[3].x:
+                 leftX = obj.text_annotations[i].bounding_poly.vertices[3].x
+             else:
+                 leftX = obj.text_annotations[i].bounding_poly.vertices[1].x
+
+             if obj.text_annotations[i].bounding_poly.vertices[1].x > obj.text_annotations[i].bounding_poly.vertices[3].x:
+                 rightX = obj.text_annotations[i].bounding_poly.vertices[1].x
+             else:
+                 rightX = obj.text_annotations[i].bounding_poly.vertices[3].x
+
+             boxObjects.append({
+                 "id": i-1,
+                 "text": obj.text_annotations[i].description,
+                 "left": leftX,
+                 "width": rightX - leftX,
+                 "top": obj.text_annotations[i].bounding_poly.vertices[1].y,
+                 "height": obj.text_annotations[i].bounding_poly.vertices[3].y - obj.text_annotations[i].bounding_poly.vertices[1].y
+             })
+
+         return boxObjects
+
+     def convert_to_df(self, boxObjects, image):
+         ocr_df = pd.DataFrame(boxObjects)
+
+         # scale pixel coordinates to the 0-1000 range expected downstream
+         width, height = image.size
+         w_scale = 1000/width
+         h_scale = 1000/height
+
+         ocr_df = ocr_df.dropna() \
+             .assign(left_scaled=ocr_df.left*w_scale,
+                     width_scaled=ocr_df.width*w_scale,
+                     top_scaled=ocr_df.top*h_scale,
+                     height_scaled=ocr_df.height*h_scale,
+                     right_scaled=lambda x: x.left_scaled + x.width_scaled,
+                     bottom_scaled=lambda x: x.top_scaled + x.height_scaled)
+
+         float_cols = ocr_df.select_dtypes('float').columns
+         ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
+         ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
+         ocr_df = ocr_df.dropna().reset_index(drop=True)
+         return ocr_df
+
+     def ocr(self, content, image):
+         resp_js = self.get_response(content)
+         boxObjects = self.post_process(resp_js)
+         ocr_df = self.convert_to_df(boxObjects, image)
+         return ocr_df
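A minimal sketch of how `VisionClient` is driven on its own, outside the FastAPI app — it assumes the service-account JSON is available in a `GCV_AUTH` environment variable (as in `config.Settings`) and uses a hypothetical local test image:

```python
import json
import os
from PIL import Image
from ocr import VisionClient

gcv_auth = json.loads(os.environ["GCV_AUTH"])  # service-account info dict
client = VisionClient(gcv_auth)

with open("sample_doc.png", "rb") as f:  # sample_doc.png is a placeholder file name
    content = f.read()
image = Image.open("sample_doc.png")

ocr_df = client.ocr(content, image)
print(ocr_df.head())  # one row per word: text, left, top, width, height plus *_scaled columns
```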
preprocess.py ADDED
@@ -0,0 +1,111 @@
+ import torch
+ from transformers import AutoTokenizer
+
+ # class to turn the keys of a dict into attributes (thanks Stackoverflow)
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
+
+ class Preprocessor():
+     def __init__(self, tokenizer):
+         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+         self.argsdict = {'max_seq_length': 512}
+         self.args = AttrDict(self.argsdict)
+
+     def get_boxes(self, ocr_df, image):
+         words = list(ocr_df.text)
+         coordinates = ocr_df[['left', 'top', 'width', 'height']]
+         actual_boxes = []
+         width, height = image.size
+         for idx, row in coordinates.iterrows():
+             x, y, w, h = tuple(row)  # the row comes in (left, top, width, height) format
+             actual_box = [x, y, x+w, y+h]  # we turn it into (left, top, left+width, top+height) to get the actual box
+             actual_boxes.append(actual_box)
+
+         def normalize_box(box, width, height):
+             return [
+                 int(1000 * (box[0] / width)),
+                 int(1000 * (box[1] / height)),
+                 int(1000 * (box[2] / width)),
+                 int(1000 * (box[3] / height)),
+             ]
+
+         boxes = []
+         for box in actual_boxes:
+             boxes.append(normalize_box(box, width, height))
+
+         return words, boxes, actual_boxes
+
+     def convert_example_to_features(self, image, words, boxes, actual_boxes, cls_token_box=[0, 0, 0, 0],
+                                     sep_token_box=[1000, 1000, 1000, 1000],
+                                     pad_token_box=[0, 0, 0, 0]):
+         width, height = image.size
+
+         tokens = []
+         token_boxes = []
+         actual_bboxes = []  # we use an extra b because actual_boxes is already used
+         token_actual_boxes = []
+         offset_mapping = []
+         for word, box, actual_bbox in zip(words, boxes, actual_boxes):
+             word_tokens = self.tokenizer.tokenize(word)
+             mapping = self.tokenizer(word, return_offsets_mapping=True).offset_mapping
+             offset_mapping.extend(mapping)
+             tokens.extend(word_tokens)
+             token_boxes.extend([box] * len(word_tokens))
+             actual_bboxes.extend([actual_bbox] * len(word_tokens))
+             token_actual_boxes.extend([actual_bbox] * len(word_tokens))
+
+         # Truncation: account for [CLS] and [SEP] with "- 2".
+         special_tokens_count = 2
+         if len(tokens) > self.args.max_seq_length - special_tokens_count:
+             tokens = tokens[: (self.args.max_seq_length - special_tokens_count)]
+             token_boxes = token_boxes[: (self.args.max_seq_length - special_tokens_count)]
+             actual_bboxes = actual_bboxes[: (self.args.max_seq_length - special_tokens_count)]
+             token_actual_boxes = token_actual_boxes[: (self.args.max_seq_length - special_tokens_count)]
+
+         # add [SEP] token, with corresponding token boxes and actual boxes
+         tokens += [self.tokenizer.sep_token]
+         token_boxes += [sep_token_box]
+         actual_bboxes += [[0, 0, width, height]]
+         token_actual_boxes += [[0, 0, width, height]]
+
+         segment_ids = [0] * len(tokens)
+
+         # next: [CLS] token
+         tokens = [self.tokenizer.cls_token] + tokens
+         token_boxes = [cls_token_box] + token_boxes
+         actual_bboxes = [[0, 0, width, height]] + actual_bboxes
+         token_actual_boxes = [[0, 0, width, height]] + token_actual_boxes
+         segment_ids = [1] + segment_ids
+
+         input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+
+         # The mask has 1 for real tokens and 0 for padding tokens. Only real
+         # tokens are attended to.
+         input_mask = [1] * len(input_ids)
+
+         # Zero-pad up to the sequence length.
+         padding_length = self.args.max_seq_length - len(input_ids)
+         input_ids += [self.tokenizer.pad_token_id] * padding_length
+         input_mask += [0] * padding_length
+         segment_ids += [self.tokenizer.pad_token_id] * padding_length
+         token_boxes += [pad_token_box] * padding_length
+         token_actual_boxes += [pad_token_box] * padding_length
+
+         assert len(input_ids) == self.args.max_seq_length
+         assert len(input_mask) == self.args.max_seq_length
+         assert len(segment_ids) == self.args.max_seq_length
+         assert len(token_boxes) == self.args.max_seq_length
+         assert len(token_actual_boxes) == self.args.max_seq_length
+
+         return input_ids, input_mask, segment_ids, token_boxes, token_actual_boxes, offset_mapping
+
+     def process(self, ocr_df, image):
+         words, boxes, actual_boxes = self.get_boxes(ocr_df, image)
+         input_ids, input_mask, segment_ids, token_boxes, token_actual_boxes, offset_mapping = self.convert_example_to_features(image=image, words=words, boxes=boxes, actual_boxes=actual_boxes)
+         input_ids = torch.tensor(input_ids).unsqueeze(0)
+         attention_mask = torch.tensor(input_mask).unsqueeze(0)
+         token_type_ids = torch.tensor(segment_ids).unsqueeze(0)
+         bbox = torch.tensor(token_boxes).unsqueeze(0)
+         return input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping
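The `normalize_box` helper rescales absolute pixel boxes to the 0-1000 coordinate space that LiLT expects. A quick worked example on a hypothetical 800×600 page:

```python
# box is (left, top, right, bottom) in pixels on an 800x600 page
box, width, height = [80, 150, 240, 180], 800, 600
normalized = [
    int(1000 * (box[0] / width)),   # 80 / 800 * 1000 = 100
    int(1000 * (box[1] / height)),  # 150 / 600 * 1000 = 250
    int(1000 * (box[2] / width)),   # 240 / 800 * 1000 = 300
    int(1000 * (box[3] / height)),  # 180 / 600 * 1000 = 300
]
print(normalized)  # [100, 250, 300, 300]
```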
requirements.txt ADDED
Binary file (3.27 kB)
 
token_classification.py ADDED
@@ -0,0 +1,36 @@
+ import numpy as np
+
+ def classifyTokens(model, input_ids, attention_mask, bbox, offset_mapping):
+     outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
+     # take argmax on the last dimension to get the predicted class ID per token
+     predictions = outputs.logits.argmax(-1).squeeze().tolist()
+     return predictions
+
+ def createEntities(model, predictions, input_ids, offset_mapping):
+     # we're only interested in tokens which aren't subwords
+     # we'll use the offset mapping for that
+     offset_mapping = np.array(offset_mapping)
+     is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
+
+     label2id = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
+
+     # finally, store recognized "question" and "answer" entities in a list
+     entities = []
+     current_entity = None
+     start = None
+     end = None
+
+     for idx, (id, pred) in enumerate(zip(input_ids[0].tolist(), predictions)):
+         if not is_subword[idx]:
+             predicted_label = model.config.id2label[pred]
+             if predicted_label.startswith("B") and current_entity is None:
+                 # means we're at the start of a new entity
+                 current_entity = predicted_label.replace("B-", "")
+                 start = idx
+             if current_entity is not None and current_entity not in predicted_label:
+                 # means we're at the end of the current entity
+                 end = idx
+                 entities.append((start, end, current_entity, label2id[current_entity]))
+                 current_entity = None
+
+     return entities
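`createEntities` returns one tuple per detected span in the form (start_token_index, end_token_index, label_name, label_id), which is exactly what main.py unpacks into the `entity_dict` passed to the RE model. A hypothetical output for a document with one question/answer pair might look like:

```python
entities = [
    (12, 15, "QUESTION", 1),  # tokens 12-15 form a question span
    (16, 20, "ANSWER", 2),    # tokens 16-20 form the answer span
]
```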