Spaces:
Runtime error
Runtime error
streetyogi
commited on
Commit
·
c9b77aa
1
Parent(s):
3f8fe73
Update inference_server.py
Browse files- inference_server.py +1 -1
inference_server.py
CHANGED
@@ -26,7 +26,7 @@ def predict(input_text: str):
|
|
26 |
logits = outputs[0]
|
27 |
|
28 |
# Find the input string that is most similar to the new input string, according to the BERT LM
|
29 |
-
similarity_scores = torch.nn.functional.
|
30 |
logits[:, -1, :], dim=1)
|
31 |
_, prediction_index = torch.max(similarity_scores, dim=0)
|
32 |
prediction = list(strings)[prediction_index]
|
|
|
26 |
logits = outputs[0]
|
27 |
|
28 |
# Find the input string that is most similar to the new input string, according to the BERT LM
|
29 |
+
similarity_scores = torch.nn.functional.cosine_similarity(logits[:, 0, :],
|
30 |
logits[:, -1, :], dim=1)
|
31 |
_, prediction_index = torch.max(similarity_scores, dim=0)
|
32 |
prediction = list(strings)[prediction_index]
|