streetyogi commited on
Commit
6429cc8
·
1 Parent(s): e4bf79d

Update inference_server.py

Browse files
Files changed (1) hide show
  1. inference_server.py +1 -2
inference_server.py CHANGED
@@ -28,8 +28,7 @@ def predict(input_text: str):
28
  logits = output[0]
29
 
30
  # Find the input string that is most similar to the new input string, according to the BERT LM
31
- similarity_scores =
32
- torch.nn.functional.csine_similarity(logits[:, 0, :],
33
  logits[:, -1, :], dim=1)
34
  _, prediction_index = torch.max(similarity_scores, dim=0)
35
  prediction = list(strings)[prediction_index]
 
28
  logits = output[0]
29
 
30
  # Find the input string that is most similar to the new input string, according to the BERT LM
31
+ similarity_scores = torch.nn.functional.csine_similarity(logits[:, 0, :],
 
32
  logits[:, -1, :], dim=1)
33
  _, prediction_index = torch.max(similarity_scores, dim=0)
34
  prediction = list(strings)[prediction_index]