File size: 1,733 Bytes
c38c542
dde26ac
3e26431
 
d9bb831
3e26431
 
 
c38c542
 
d9bb831
 
 
420d681
c38c542
 
d9bb831
e4bf79d
d9bb831
 
420d681
d9bb831
 
 
 
420d681
d9bb831
6429cc8
d9bb831
 
 
 
077a630
c38c542
3e26431
 
 
 
bb2c44a
dde26ac
 
 
c38c542
dde26ac
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import functools
import logging
import pickle
import threading

import torch
import transformers
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse

app = FastAPI()

# Distinct input strings handed to `predict`, which adds every input it receives.
strings = set()

# Pretrained bare BERT encoder (no language-model or classification head).
# `.eval()` switches it to inference mode (disables dropout) and returns the
# same module, so it can be chained onto the load.
model = transformers.BertModel.from_pretrained('bert-base-cased').eval()
@functools.lru_cache(maxsize=1)
def _get_tokenizer():
    """Return the ``bert-base-cased`` tokenizer, loading it only on first use.

    Cached so requests do not reload the tokenizer from disk each time.
    """
    return transformers.BertTokenizer.from_pretrained('bert-base-cased')


# Embedding per input string, so each distinct string goes through BERT once
# rather than on every request. It grows at the same rate as `strings`.
_embedding_cache = {}

# FastAPI runs plain `def` endpoints in a worker thread pool; this lock makes
# "add to `strings`" + "snapshot `strings`" a single consistent step.
_strings_lock = threading.Lock()


def _embed(text):
    """Return a 1-D embedding of *text*.

    The embedding is BERT's last hidden state, mean-pooled over the tokens
    the attention mask marks as real.
    """
    encoded = _get_tokenizer()(
        text,
        max_length=512,     # BERT's maximum sequence length
        truncation=True,    # cut overly long inputs instead of crashing the model
        return_tensors='pt',
    )
    with torch.no_grad():
        outputs = model(**encoded)  # passes input_ids, token_type_ids, attention_mask
    hidden = outputs[0]  # last_hidden_state: (1, seq_len, hidden_size)
    mask = encoded['attention_mask'].unsqueeze(-1).to(hidden.dtype)
    # [CLS] and [SEP] are always present, so the mask sum is never zero.
    return ((hidden * mask).sum(dim=1) / mask.sum(dim=1)).squeeze(0)


def _cached_embedding(text):
    """Return the embedding of *text*, computing and memoizing it on first use."""
    embedding = _embedding_cache.get(text)
    if embedding is None:
        # Two concurrent first requests for the same text may both compute it;
        # setdefault keeps whichever result was stored first.
        embedding = _embedding_cache.setdefault(text, _embed(text))
    return embedding


def predict(input_text: str):
    """Store *input_text* and return the most similar previously stored string.

    Similarity is the cosine similarity between mean-pooled BERT embeddings.
    The input is never compared against itself.

    Args:
        input_text: The new string. It is added to the module-level ``strings`` set.

    Returns:
        A dict with two keys. ``"prediction"`` is the stored string (other
        than *input_text*) whose embedding is closest to *input_text*'s, or
        *input_text* itself when no other string is stored yet.
        ``"num_strings"`` is the number of distinct strings stored at the
        time of this call.
    """
    with _strings_lock:
        strings.add(input_text)
        # One snapshot, so positions in it stay valid even if other requests
        # add strings while the model runs.
        snapshot = list(strings)

    candidates = [s for s in snapshot if s != input_text]
    if not candidates:
        # Nothing else to compare against: the input is its own best match.
        return {"prediction": input_text, "num_strings": len(snapshot)}

    query = _cached_embedding(input_text)
    candidate_embeddings = torch.stack([_cached_embedding(s) for s in candidates])
    # (k, hidden) vs (1, hidden) broadcasts to one score per candidate.
    similarity_scores = torch.nn.functional.cosine_similarity(
        candidate_embeddings, query.unsqueeze(0), dim=1
    )
    best_index = int(torch.argmax(similarity_scores))
    return {"prediction": candidates[best_index], "num_strings": len(snapshot)}
    
# Models are loaded once at import time near the top of this module; the HTTP endpoints follow.

@app.get("/")
def read_root(input_text: str):
    """Return the stored string most similar to the ``input_text`` query parameter.

    On success the response body is the dict produced by ``predict``. If
    ``predict`` raises, the traceback is logged and an HTTP 500 response
    with body ``{"error": "<message>"}`` is returned.
    """
    logging.info("Received request with input_text: %s", input_text)
    try:
        result = predict(input_text)
    except Exception as e:  # top-level request boundary: log and report
        # logging.exception also records the traceback.
        logging.exception("An error occured: %s", e)
        # NOTE(review): str(e) may expose internal details to clients; kept so
        # the error body is unchanged for existing callers.
        return JSONResponse(status_code=500, content={"error": str(e)})
    logging.info("Prediction made: %s", result)
    return result