sagawa commited on
Commit
c454156
·
1 Parent(s): b823009

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -8
app.py CHANGED
@@ -11,7 +11,7 @@ import pandas as pd
11
  from tqdm.auto import tqdm
12
  import tokenizers
13
  import transformers
14
- from transformers import AutoTokenizer, AutoConfig, AutoModel, T5EncoderModel, get_linear_schedule_with_warmup
15
  import datasets
16
  from datasets import load_dataset, load_metric
17
  import sentencepiece
@@ -102,25 +102,55 @@ if st.button('predict'):
102
  self.config = torch.load(config_path)
103
  if pretrained:
104
  if 't5' in cfg.model:
105
- self.model = T5EncoderModel.from_pretrained(CFG.pretrained_model_name_or_path)
106
  else:
107
  self.model = AutoModel.from_pretrained(CFG.pretrained_model_name_or_path)
108
  else:
109
  if 't5' in cfg.model:
110
- self.model = T5EncoderModel.from_pretrained('sagawa/ZINC-t5')
111
  else:
112
  self.model = AutoModel.from_config(self.config)
113
  self.model.resize_token_embeddings(len(cfg.tokenizer))
114
  self.fc_dropout1 = nn.Dropout(cfg.fc_dropout)
115
- self.fc1 = nn.Linear(self.config.hidden_size, self.config.hidden_size)
116
  self.fc_dropout2 = nn.Dropout(cfg.fc_dropout)
117
- self.fc2 = nn.Linear(self.config.hidden_size, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  def forward(self, inputs):
120
- outputs = self.model(**inputs)
 
 
 
 
 
121
  last_hidden_states = outputs[0]
122
- output = self.fc1(self.fc_dropout1(last_hidden_states)[:, 0, :].view(-1, self.config.hidden_size))
123
- output = self.fc2(self.fc_dropout2(output))
 
 
 
124
  return output
125
 
126
 
 
11
  from tqdm.auto import tqdm
12
  import tokenizers
13
  import transformers
14
+ from transformers import AutoTokenizer, AutoConfig, AutoModel, T5EncoderModel, get_linear_schedule_with_warmup, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
15
  import datasets
16
  from datasets import load_dataset, load_metric
17
  import sentencepiece
 
102
  self.config = torch.load(config_path)
103
  if pretrained:
104
  if 't5' in cfg.model:
105
+ self.model = T5ForConditionalGeneration.from_pretrained(CFG.pretrained_model_name_or_path)
106
  else:
107
  self.model = AutoModel.from_pretrained(CFG.pretrained_model_name_or_path)
108
  else:
109
  if 't5' in cfg.model:
110
+ self.model = T5ForConditionalGeneration.from_pretrained('sagawa/ZINC-t5')
111
  else:
112
  self.model = AutoModel.from_config(self.config)
113
  self.model.resize_token_embeddings(len(cfg.tokenizer))
114
  self.fc_dropout1 = nn.Dropout(cfg.fc_dropout)
115
+ self.fc1 = nn.Linear(self.config.hidden_size, self.config.hidden_size//2)
116
  self.fc_dropout2 = nn.Dropout(cfg.fc_dropout)
117
+
118
+ self.fc2 = nn.Linear(self.config.hidden_size, self.config.hidden_size//2)
119
+ self.fc3 = nn.Linear(self.config.hidden_size//2*2, self.config.hidden_size)
120
+ self.fc4 = nn.Linear(self.config.hidden_size, self.config.hidden_size)
121
+ self.fc5 = nn.Linear(self.config.hidden_size, 1)
122
+
123
+ self._init_weights(self.fc1)
124
+ self._init_weights(self.fc2)
125
+ self._init_weights(self.fc3)
126
+ self._init_weights(self.fc4)
127
+
128
+ def _init_weights(self, module):
129
+ if isinstance(module, nn.Linear):
130
+ module.weight.data.normal_(mean=0.0, std=0.01)
131
+ if module.bias is not None:
132
+ module.bias.data.zero_()
133
+ elif isinstance(module, nn.Embedding):
134
+ module.weight.data.normal_(mean=0.0, std=0.01)
135
+ if module.padding_idx is not None:
136
+ module.weight.data[module.padding_idx].zero_()
137
+ elif isinstance(module, nn.LayerNorm):
138
+ module.bias.data.zero_()
139
+ module.weight.data.fill_(1.0)
140
 
141
  def forward(self, inputs):
142
+ encoder_outputs = self.model.encoder(**inputs)
143
+ encoder_hidden_states = encoder_outputs[0]
144
+ outputs = self.model.decoder(input_ids=torch.full((inputs['input_ids'].size(0),1),
145
+ self.config.decoder_start_token_id,
146
+ dtype=torch.long,
147
+ device=device), encoder_hidden_states=encoder_hidden_states)
148
  last_hidden_states = outputs[0]
149
+ output1 = self.fc1(self.fc_dropout1(last_hidden_states).view(-1, self.config.hidden_size))
150
+ output2 = self.fc2(encoder_hidden_states[:, 0, :].view(-1, self.config.hidden_size))
151
+ output = self.fc3(self.fc_dropout2(torch.hstack((output1, output2))))
152
+ output = self.fc4(output)
153
+ output = self.fc5(output)
154
  return output
155
 
156