Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -11,7 +11,7 @@ import pandas as pd
|
|
11 |
from tqdm.auto import tqdm
|
12 |
import tokenizers
|
13 |
import transformers
|
14 |
-
from transformers import AutoTokenizer, AutoConfig, AutoModel, T5EncoderModel, get_linear_schedule_with_warmup
|
15 |
import datasets
|
16 |
from datasets import load_dataset, load_metric
|
17 |
import sentencepiece
|
@@ -102,25 +102,55 @@ if st.button('predict'):
|
|
102 |
self.config = torch.load(config_path)
|
103 |
if pretrained:
|
104 |
if 't5' in cfg.model:
|
105 |
-
self.model =
|
106 |
else:
|
107 |
self.model = AutoModel.from_pretrained(CFG.pretrained_model_name_or_path)
|
108 |
else:
|
109 |
if 't5' in cfg.model:
|
110 |
-
self.model =
|
111 |
else:
|
112 |
self.model = AutoModel.from_config(self.config)
|
113 |
self.model.resize_token_embeddings(len(cfg.tokenizer))
|
114 |
self.fc_dropout1 = nn.Dropout(cfg.fc_dropout)
|
115 |
-
self.fc1 = nn.Linear(self.config.hidden_size, self.config.hidden_size)
|
116 |
self.fc_dropout2 = nn.Dropout(cfg.fc_dropout)
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
def forward(self, inputs):
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
121 |
last_hidden_states = outputs[0]
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
124 |
return output
|
125 |
|
126 |
|
|
|
11 |
from tqdm.auto import tqdm
|
12 |
import tokenizers
|
13 |
import transformers
|
14 |
+
from transformers import AutoTokenizer, AutoConfig, AutoModel, T5EncoderModel, get_linear_schedule_with_warmup, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
|
15 |
import datasets
|
16 |
from datasets import load_dataset, load_metric
|
17 |
import sentencepiece
|
|
|
102 |
self.config = torch.load(config_path)
|
103 |
if pretrained:
|
104 |
if 't5' in cfg.model:
|
105 |
+
self.model = T5ForConditionalGeneration.from_pretrained(CFG.pretrained_model_name_or_path)
|
106 |
else:
|
107 |
self.model = AutoModel.from_pretrained(CFG.pretrained_model_name_or_path)
|
108 |
else:
|
109 |
if 't5' in cfg.model:
|
110 |
+
self.model = T5ForConditionalGeneration.from_pretrained('sagawa/ZINC-t5')
|
111 |
else:
|
112 |
self.model = AutoModel.from_config(self.config)
|
113 |
self.model.resize_token_embeddings(len(cfg.tokenizer))
|
114 |
self.fc_dropout1 = nn.Dropout(cfg.fc_dropout)
|
115 |
+
self.fc1 = nn.Linear(self.config.hidden_size, self.config.hidden_size//2)
|
116 |
self.fc_dropout2 = nn.Dropout(cfg.fc_dropout)
|
117 |
+
|
118 |
+
self.fc2 = nn.Linear(self.config.hidden_size, self.config.hidden_size//2)
|
119 |
+
self.fc3 = nn.Linear(self.config.hidden_size//2*2, self.config.hidden_size)
|
120 |
+
self.fc4 = nn.Linear(self.config.hidden_size, self.config.hidden_size)
|
121 |
+
self.fc5 = nn.Linear(self.config.hidden_size, 1)
|
122 |
+
|
123 |
+
self._init_weights(self.fc1)
|
124 |
+
self._init_weights(self.fc2)
|
125 |
+
self._init_weights(self.fc3)
|
126 |
+
self._init_weights(self.fc4)
|
127 |
+
|
128 |
+
def _init_weights(self, module):
|
129 |
+
if isinstance(module, nn.Linear):
|
130 |
+
module.weight.data.normal_(mean=0.0, std=0.01)
|
131 |
+
if module.bias is not None:
|
132 |
+
module.bias.data.zero_()
|
133 |
+
elif isinstance(module, nn.Embedding):
|
134 |
+
module.weight.data.normal_(mean=0.0, std=0.01)
|
135 |
+
if module.padding_idx is not None:
|
136 |
+
module.weight.data[module.padding_idx].zero_()
|
137 |
+
elif isinstance(module, nn.LayerNorm):
|
138 |
+
module.bias.data.zero_()
|
139 |
+
module.weight.data.fill_(1.0)
|
140 |
|
141 |
def forward(self, inputs):
    """Run the encoder, a one-step decoder pass, and the fully connected head.

    Args:
        inputs: mapping of keyword arguments forwarded verbatim to
            ``self.model.encoder``; must contain ``'input_ids'``, whose
            first dimension gives the batch size.

    Returns:
        The output of ``self.fc5``, which has a single feature per row.
    """
    encoder_outputs = self.model.encoder(**inputs)
    encoder_hidden_states = encoder_outputs[0]

    # Feed the decoder exactly one start token per batch row. The tokens are
    # created on the same device as the encoder output instead of relying on
    # a module-level `device` global that may not match where the model lives.
    batch_size = inputs['input_ids'].size(0)
    decoder_input_ids = torch.full(
        (batch_size, 1),
        self.config.decoder_start_token_id,
        dtype=torch.long,
        device=encoder_hidden_states.device,
    )
    # NOTE(review): no encoder attention mask is passed to the decoder here;
    # left unchanged so behaviour stays the same — confirm this is intended.
    outputs = self.model.decoder(
        input_ids=decoder_input_ids,
        encoder_hidden_states=encoder_hidden_states,
    )
    last_hidden_states = outputs[0]

    hidden_size = self.config.hidden_size
    # Branch 1: decoder state, flattened to rows of hidden_size.
    output1 = self.fc1(
        self.fc_dropout1(last_hidden_states).reshape(-1, hidden_size)
    )
    # Branch 2: encoder state at sequence position 0.
    output2 = self.fc2(
        encoder_hidden_states[:, 0, :].reshape(-1, hidden_size)
    )

    # Concatenate both branches along the feature axis and run the head.
    output = self.fc3(self.fc_dropout2(torch.hstack((output1, output2))))
    output = self.fc4(output)
    output = self.fc5(output)
    return output
|
155 |
|
156 |
|