Spaces:
Runtime error
Runtime error
Commit
·
d358457
1
Parent(s):
c9b77aa
Update inference_server.py
Browse files- inference_server.py +4 -1
inference_server.py
CHANGED
@@ -12,11 +12,14 @@ strings = set() # Set to store all input strings
|
|
12 |
# Load the BERT LM and set it to eval mode
|
13 |
model = transformers.BertModel.from_pretrained('bert-base-cased')
|
14 |
model.eval()
|
|
|
|
|
|
|
15 |
def predict(input_text: str):
|
16 |
# Add the new input string to the set of strings
|
17 |
strings.add(input_text)
|
18 |
# Convert the input strings to input tensors for the BERT LM
|
19 |
-
input_tensors =
|
20 |
pad_to_max_length=True, return_tensors='pt')
|
21 |
input_ids = input_tensors['input_ids']
|
22 |
|
|
|
12 |
# Load the BERT LM and set it to eval mode
|
13 |
model = transformers.BertModel.from_pretrained('bert-base-cased')
|
14 |
model.eval()
|
15 |
+
|
16 |
+
# Load the BERT tokenizer
|
17 |
+
tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
|
18 |
def predict(input_text: str):
|
19 |
# Add the new input string to the set of strings
|
20 |
strings.add(input_text)
|
21 |
# Convert the input strings to input tensors for the BERT LM
|
22 |
+
input_tensors = tokenizer.batch_encode_plus(list(strings), max_length=512,
|
23 |
pad_to_max_length=True, return_tensors='pt')
|
24 |
input_ids = input_tensors['input_ids']
|
25 |
|