tskolm committed on
Commit
afa237b
·
1 Parent(s): 7ccbcc9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -5,21 +5,26 @@ import sys
5
  import urllib
6
  import json
7
  import torch
 
8
 
9
  def generate(tokenizer, model, text, features):
10
- generated = tokenizer("<|startoftext|> <|titlestart|>{}<|titleend|>".format(text), return_tensors="pt").input_ids
11
  sample_outputs = model.generate(
12
  generated, do_sample=True, top_k=50,
13
  max_length=features['max_length'], top_p=features['top_p'], temperature=features['t'] / 100.0, num_return_sequences=features['num'],
14
  )
15
  for i, sample_output in enumerate(sample_outputs):
16
- decoded = tokenizer.decode(sample_output, skip_special_tokens=True).replace('\\\\', '\\').split(text)[1]
17
- st.write(decoded)
 
18
 
19
 
20
  def load_model():
21
  tokenizer = torch.load('./tokenizer.pt')
22
- model = torch.load('./model.pt', map_location=torch.device('cpu'))
 
 
 
23
  return tokenizer, model
24
 
25
 
 
5
  import urllib
6
  import json
7
  import torch
8
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
9
 
10
  def generate(tokenizer, model, text, features):
11
+ generated = tokenizer("<|startoftext|> <|titlestart|>{}<|titleend|><|authornamebegin|>".format(text), return_tensors="pt").input_ids
12
  sample_outputs = model.generate(
13
  generated, do_sample=True, top_k=50,
14
  max_length=features['max_length'], top_p=features['top_p'], temperature=features['t'] / 100.0, num_return_sequences=features['num'],
15
  )
16
  for i, sample_output in enumerate(sample_outputs):
17
+ decoded = tokenizer.decode(sample_output, skip_special_tokens=False)
18
+ autor, text = decoded.split('<|authornamebegin|>')[1].split('<|authornameend|>')
19
+ st.markdown('**' + author.strip() + '**: ' + text.replace('<|endoftext|>', '').replace('<|pad|>', '').strip())
20
 
21
 
22
  def load_model():
23
  tokenizer = torch.load('./tokenizer.pt')
24
+ config = GPT2Config.from_json_file('./config.json')
25
+ model = GPT2LMHeadModel(config)
26
+ state_dict = torch.load('./pytorch_model.bin', map_location=torch.device('cpu'))
27
+ model.load_state_dict(state_dict)
28
  return tokenizer, model
29
 
30