streetyogi commited on
Commit
d358457
·
1 Parent(s): c9b77aa

Update inference_server.py

Browse files
Files changed (1) hide show
  1. inference_server.py +4 -1
inference_server.py CHANGED
@@ -12,11 +12,14 @@ strings = set() # Set to store all input strings
12
  # Load the BERT LM and set it to eval mode
13
  model = transformers.BertModel.from_pretrained('bert-base-cased')
14
  model.eval()
 
 
 
15
  def predict(input_text: str):
16
  # Add the new input string to the set of strings
17
  strings.add(input_text)
18
  # Convert the input strings to input tensors for the BERT LM
19
- input_tensors = transformers.BertTokenizer.from_pretrained('bert-base-cased').batch_encode_plus(list(strings), max_length=512,
20
  pad_to_max_length=True, return_tensors='pt')
21
  input_ids = input_tensors['input_ids']
22
 
 
12
  # Load the BERT LM and set it to eval mode
13
  model = transformers.BertModel.from_pretrained('bert-base-cased')
14
  model.eval()
15
+
16
+ # Load the BERT tokenizer
17
+ tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
18
  def predict(input_text: str):
19
  # Add the new input string to the set of strings
20
  strings.add(input_text)
21
  # Convert the input strings to input tensors for the BERT LM
22
+ input_tensors = tokenizer.batch_encode_plus(list(strings), max_length=512,
23
  pad_to_max_length=True, return_tensors='pt')
24
  input_ids = input_tensors['input_ids']
25