Jyiyiyiyi committed
Commit 982e027 · 1 Parent(s): 140a8c4

Upload MarkuplmTransformerForConMATH.py

0_Asym/140000857040976_MarkuplmTransformerForConMATH/MarkuplmTransformerForConMATH.py ADDED
@@ -0,0 +1,186 @@
+ from torch import nn
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, T5Config, MT5Config, MarkupLMProcessor, MarkupLMTokenizer
+ import json
+ from typing import List, Dict, Optional, Union, Tuple
+ import os
+
+
+ class MarkuplmTransformerForConMATH(nn.Module):
+     """Huggingface AutoModel to generate token embeddings.
+     Loads the correct class, e.g. BERT / RoBERTa etc.
+
+     :param model_name_or_path: Huggingface model name (https://huggingface.co/models)
+     :param max_seq_length: Truncate any inputs longer than max_seq_length
+     :param model_args: Arguments (key, value pairs) passed to the Huggingface Transformers model
+     :param cache_dir: Cache dir for Huggingface Transformers to store/load models
+     :param tokenizer_args: Arguments (key, value pairs) passed to the Huggingface Tokenizer
+     :param processor_args: Arguments (key, value pairs) passed to the MarkupLMProcessor
+     :param do_lower_case: If true, lowercases the input (independent of whether the model is cased or not)
+     :param tokenizer_name_or_path: Name or path of the tokenizer. When None, then model_name_or_path is used
+     :param processor_name_or_path: Name or path of the processor. When None, then model_name_or_path is used
+     """
+     def __init__(self, model_name_or_path: str, max_seq_length: Optional[int] = None,
+                  model_args: Dict = {}, cache_dir: Optional[str] = None,
+                  tokenizer_args: Dict = {}, processor_args: Dict = {}, do_lower_case: bool = False,
+                  tokenizer_name_or_path: str = None, processor_name_or_path: str = None):
+         super(MarkuplmTransformerForConMATH, self).__init__()
+         self.config_keys = ['max_seq_length', 'do_lower_case']
+         self.do_lower_case = do_lower_case
+
+         config = AutoConfig.from_pretrained(model_name_or_path, **model_args, cache_dir=cache_dir)
+         #print('config:' + str(config))
+
+         self._load_model(model_name_or_path, config, cache_dir, **model_args)
+
+         self.tokenizer = MarkupLMTokenizer.from_pretrained(tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path, cache_dir=cache_dir, **tokenizer_args)
+         self.processor = MarkupLMProcessor.from_pretrained(processor_name_or_path if processor_name_or_path is not None else model_name_or_path, cache_dir=cache_dir, **processor_args)
+
+         #No max_seq_length set. Try to infer from model
+         if max_seq_length is None:
+             #if hasattr(self.auto_model, "config") and hasattr(self.auto_model.config, "max_position_embeddings") and hasattr(self.tokenizer, "model_max_length"):
+             #max_seq_length = min(self.auto_model.config.max_position_embeddings, self.processor.model_max_length)
+             max_seq_length = 512
+
+         self.max_seq_length = max_seq_length
+
+         if tokenizer_name_or_path is not None:
+             self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
+
+
+     def _load_model(self, model_name_or_path, config, cache_dir, **model_args):
+         """Loads the transformer model"""
+         if isinstance(config, T5Config):
+             self._load_t5_model(model_name_or_path, config, cache_dir, **model_args)
+         elif isinstance(config, MT5Config):
+             self._load_mt5_model(model_name_or_path, config, cache_dir, **model_args)
+         else:
+             self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
+
+     def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args):
+         """Loads the encoder model from T5"""
+         from transformers import T5EncoderModel
+         T5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
+         self.auto_model = T5EncoderModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
+
+     def _load_mt5_model(self, model_name_or_path, config, cache_dir, **model_args):
+         """Loads the encoder model from MT5"""
+         from transformers import MT5EncoderModel
+         MT5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
+         self.auto_model = MT5EncoderModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
+
+     def __repr__(self):
+         return "MarkuplmTransformer({}) with Transformer model: {} ".format(self.get_config_dict(), self.auto_model.__class__.__name__)
+
+     def forward(self, features):
+         """Returns token_embeddings, cls_token"""
+         trans_features = {'input_ids': features['input_ids'], 'xpath_tags_seq': features['xpath_tags_seq'],
+                           'xpath_subs_seq': features['xpath_subs_seq'], 'attention_mask': features['attention_mask']}
+         if 'token_type_ids' in features:
+             trans_features['token_type_ids'] = features['token_type_ids']
+
+         output_states = self.auto_model(**trans_features, return_dict=False)
+         output_tokens = output_states[0]
+
+         features.update({'token_embeddings': output_tokens, 'attention_mask': features['attention_mask']})
+
+         if self.auto_model.config.output_hidden_states:
+             all_layer_idx = 2
+             if len(output_states) < 3: #Some models only output last_hidden_states and all_hidden_states
+                 all_layer_idx = 1
+
+             hidden_states = output_states[all_layer_idx]
+             features.update({'all_layer_embeddings': hidden_states})
+
+         return features
+
+     def get_word_embedding_dimension(self) -> int:
+         return self.auto_model.config.hidden_size
+
+     def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]):
+         #print("**************mathml process***************")
+         output = {}
+         if isinstance(texts[0], str):
+             to_process = [texts]
+         elif isinstance(texts[0], dict):
+             to_process = []
+             output['text_keys'] = []
+             for lookup in texts:
+                 text_key, text = next(iter(lookup.items()))
+                 to_process.append(text)
+                 output['text_keys'].append(text_key)
+             to_process = [to_process]
+         else:
+             batch1, batch2 = [], []
+             for text_tuple in texts:
+                 batch1.append(text_tuple[0])
+                 batch2.append(text_tuple[1])
+             to_process = [batch1, batch2]
+
+         # strip
+         to_process = [[str(s).strip() for s in col] for col in to_process]
+
+         # Lowercase
+         if self.do_lower_case:
+             to_process = [[s.lower() for s in col] for col in to_process]
+
+         output.update(self.processor(*to_process, padding=True, truncation='longest_first', return_tensors="pt",
+                                      max_length=self.max_seq_length))
+         return output
+     '''
+     def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]):
+         """
+         Tokenizes a text and maps tokens to token-ids
+         """
+         output = {}
+         if isinstance(texts[0], str):
+             to_tokenize = [texts]
+         elif isinstance(texts[0], dict):
+             to_tokenize = []
+             output['text_keys'] = []
+             for lookup in texts:
+                 text_key, text = next(iter(lookup.items()))
+                 to_tokenize.append(text)
+                 output['text_keys'].append(text_key)
+             to_tokenize = [to_tokenize]
+         else:
+             batch1, batch2 = [], []
+             for text_tuple in texts:
+                 batch1.append(text_tuple[0])
+                 batch2.append(text_tuple[1])
+             to_tokenize = [batch1, batch2]
+
+         #strip
+         to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]
+
+         #Lowercase
+         if self.do_lower_case:
+             to_tokenize = [[s.lower() for s in col] for col in to_tokenize]
+
+         output.update(self.tokenizer(*to_tokenize, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_seq_length))
+         return output
+     '''
+
+     def get_config_dict(self):
+         return {key: self.__dict__[key] for key in self.config_keys}
+
+     def save(self, output_path: str):
+         self.auto_model.save_pretrained(output_path)
+         self.processor.save_pretrained(output_path)
+
+         with open(os.path.join(output_path, 'sentence_bert_config.json'), 'w') as fOut:
+             json.dump(self.get_config_dict(), fOut, indent=2)
+
+     @staticmethod
+     def load(input_path: str):
+         #Old classes used other config names than 'sentence_bert_config.json'
+         for config_name in ['sentence_bert_config.json', 'sentence_roberta_config.json', 'sentence_distilbert_config.json', 'sentence_camembert_config.json', 'sentence_albert_config.json', 'sentence_xlm-roberta_config.json', 'sentence_xlnet_config.json']:
+             sbert_config_path = os.path.join(input_path, config_name)
+             if os.path.exists(sbert_config_path):
+                 break
+
+         with open(sbert_config_path) as fIn:
+             config = json.load(fIn)
+         return MarkuplmTransformerForConMATH(model_name_or_path=input_path, **config)
+
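
For reference, a minimal usage sketch of the uploaded module (not part of the commit). It assumes the file is importable as MarkuplmTransformerForConMATH.py, uses the public microsoft/markuplm-base checkpoint, and feeds illustrative MathML snippets; the variable and directory names below are hypothetical.

import torch

from MarkuplmTransformerForConMATH import MarkuplmTransformerForConMATH

# Build the module. The MarkupLM tokenizer and processor are loaded from the same
# name_or_path unless tokenizer_name_or_path / processor_name_or_path are given.
model = MarkuplmTransformerForConMATH("microsoft/markuplm-base", max_seq_length=512)
model.eval()

# Illustrative markup inputs. tokenize() routes them through MarkupLMProcessor,
# which parses the markup and returns input_ids, attention_mask, token_type_ids,
# xpath_tags_seq and xpath_subs_seq.
markup_batch = [
    "<math><mi>x</mi><mo>+</mo><mn>1</mn></math>",
    "<math><msup><mi>a</mi><mn>2</mn></msup></math>",
]
features = model.tokenize(markup_batch)

# forward() passes those features to the underlying MarkupLM encoder and adds
# 'token_embeddings' of shape (batch, seq_len, hidden_size) to the dict.
with torch.no_grad():
    features = model(features)

print(features["token_embeddings"].shape)
print(model.get_word_embedding_dimension())

# save()/load() round-trip: writes the model, the processor and
# sentence_bert_config.json into the (hypothetical) output directory.
model.save("markuplm_conmath_module")
restored = MarkuplmTransformerForConMATH.load("markuplm_conmath_module")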