sagawa commited on
Commit
6e92bdc
·
1 Parent(s): 2548c76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -35,7 +35,13 @@ def seed_everything(seed=42):
35
  torch.backends.cudnn.deterministic = True
36
  seed_everything(seed=CFG.seed)
37
 
38
-
 
 
 
 
 
 
39
  input_compound = CFG.input_data
40
  min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
41
  inp = tokenizer(input_compound, return_tensors='pt').to(device)
 
35
  torch.backends.cudnn.deterministic = True
36
  seed_everything(seed=CFG.seed)
37
 
38
+ tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path, return_tensors='pt')
39
+
40
+ if CFG.model == 't5':
41
+ model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model_name_or_path).to(device)
42
+ elif CFG.model == 'deberta':
43
+ model = EncoderDecoderModel.from_pretrained(CFG.model_name_or_path).to(device)
44
+
45
  input_compound = CFG.input_data
46
  min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
47
  inp = tokenizer(input_compound, return_tensors='pt').to(device)