File size: 2,146 Bytes
c38c542
dde26ac
3e26431
 
d9bb831
82eb6ea
3e848f0
3e26431
 
 
c38c542
 
d9bb831
 
 
d358457
 
 
420d681
3e848f0
12a9883
3e848f0
 
 
12a9883
3e848f0
 
12a9883
 
 
 
d9bb831
d358457
d9bb831
 
420d681
d9bb831
 
 
3f8fe73
420d681
d9bb831
c9b77aa
d9bb831
 
 
 
077a630
c38c542
3e26431
 
 
 
bb2c44a
dde26ac
 
 
c38c542
dde26ac
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import pickle
import logging 
import uvicorn
from fastapi import FastAPI
import transformers
import torch
import fcntl

# FastAPI application instance; routes are registered against it below.
app = FastAPI()

# NOTE(review): this module-level set is never written to — predict() rebinds
# a *local* `strings` from strings.txt instead. Kept for interface stability;
# consider removing if nothing outside this file reads it.
strings = set() # Set to store all input strings

# Load the pretrained BERT encoder once at import time and switch it to
# eval mode (disables dropout) so inference is deterministic.
model = transformers.BertModel.from_pretrained('bert-base-cased')
model.eval()

# Matching tokenizer for the same checkpoint ('bert-base-cased').
tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
def predict(input_text: str):
    """Store *input_text* and return the previously stored string most
    similar to it, measured by cosine similarity of BERT [CLS] embeddings.

    Side effects: appends *input_text* to ``strings.txt`` (under an
    exclusive lock so concurrent workers do not interleave writes).

    Returns:
        dict with ``prediction`` (best-matching stored string) and
        ``num_strings`` (count of distinct stored strings).
    """
    # Append the new input string to the shared file under an exclusive lock.
    with open('strings.txt', 'a') as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        f.write(input_text + '\n')
        fcntl.flock(f, fcntl.LOCK_UN)

    # Re-read the whole store, deduplicate, and fix a deterministic order:
    # iterating a raw set varies across processes (hash randomization),
    # which made the original argmax -> string mapping unstable.
    with open('strings.txt', 'r') as f:
        stored = sorted(set(f.read().splitlines()))

    # With a single stored string the only possible answer is the input
    # itself (matches the original behavior for this case).
    if len(stored) == 1:
        return {"prediction": stored[0], "num_strings": 1}

    # Tokenize all strings as one padded batch. `padding=True` +
    # `truncation=True` replace the deprecated `pad_to_max_length`.
    encoded = tokenizer.batch_encode_plus(
        stored,
        max_length=512,
        truncation=True,
        padding=True,
        return_tensors='pt',
    )

    # Encode every string; pass the attention mask so padding tokens do not
    # contaminate the hidden states. Use the [CLS] (first-token) hidden
    # state as the sentence embedding.
    with torch.no_grad():
        outputs = model(encoded['input_ids'],
                        attention_mask=encoded['attention_mask'])
        embeddings = outputs[0][:, 0, :]

    # Compare the NEW input's embedding against every stored string — the
    # original code compared each string's first token to its own last
    # token, never involving the input at all. Mask out the input itself
    # (its self-similarity is trivially maximal).
    query_index = stored.index(input_text)
    scores = torch.nn.functional.cosine_similarity(
        embeddings[query_index].unsqueeze(0), embeddings, dim=1)
    scores[query_index] = float('-inf')
    prediction = stored[int(torch.argmax(scores))]

    return {"prediction": prediction, "num_strings": len(stored)}
    
# Here you can do things such as load your models

@app.get("/")
def read_root(input_text: str):
    """HTTP entry point: return the stored string most similar to *input_text*.

    On success returns the prediction dict from ``predict``; on any failure
    returns ``{"error": ...}`` (still HTTP 200) so the caller always
    receives JSON.
    """
    logging.info("Received request with input_text: %s", input_text)
    try:
        result = predict(input_text)
        logging.info("Prediction made: %s", result)
        return result
    except Exception as e:
        # logging.exception records the full traceback, not just the
        # message (the original logging.error discarded it); also fixes
        # the "occured" typo.
        logging.exception("An error occurred: %s", e)
        return {"error": str(e)}