Spaces:
Runtime error
Runtime error
streetyogi
commited on
Commit
·
6429cc8
1
Parent(s):
e4bf79d
Update inference_server.py
Browse files- inference_server.py +1 -2
inference_server.py
CHANGED
@@ -28,8 +28,7 @@ def predict(input_text: str):
|
|
28 |
logits = output[0]
|
29 |
|
30 |
# Find the input string that is most similar to the new input string, according to the BERT LM
|
31 |
-
similarity_scores =
|
32 |
-
torch.nn.functional.csine_similarity(logits[:, 0, :],
|
33 |
logits[:, -1, :], dim=1)
|
34 |
_, prediction_index = torch.max(similarity_scores, dim=0)
|
35 |
prediction = list(strings)[prediction_index]
|
|
|
28 |
logits = output[0]
|
29 |
|
30 |
# Find the input string that is most similar to the new input string, according to the BERT LM
|
31 |
+
similarity_scores = torch.nn.functional.csine_similarity(logits[:, 0, :],
|
|
|
32 |
logits[:, -1, :], dim=1)
|
33 |
_, prediction_index = torch.max(similarity_scores, dim=0)
|
34 |
prediction = list(strings)[prediction_index]
|