Spaces:
Runtime error
Runtime error
File size: 2,146 Bytes
c38c542 dde26ac 3e26431 d9bb831 82eb6ea 3e848f0 3e26431 c38c542 d9bb831 d358457 420d681 3e848f0 12a9883 3e848f0 12a9883 3e848f0 12a9883 d9bb831 d358457 d9bb831 420d681 d9bb831 3f8fe73 420d681 d9bb831 c9b77aa d9bb831 077a630 c38c542 3e26431 bb2c44a dde26ac c38c542 dde26ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import pickle
import logging
import uvicorn
from fastapi import FastAPI
import transformers
import torch
import fcntl
app = FastAPI()
# NOTE(review): `predict` below rebinds a *local* variable named `strings`,
# so this module-level set is never updated after startup; the file
# 'strings.txt' is the actual persistent store. Confirm whether this global
# is still needed at all.
strings = set() # Set to store all input strings
# Load the BERT LM and set it to eval mode
# (eval() disables dropout/batchnorm updates; the model is inference-only here).
model = transformers.BertModel.from_pretrained('bert-base-cased')
model.eval()
# Load the BERT tokenizer (same 'bert-base-cased' checkpoint as the model)
tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
def predict(input_text: str):
    """Append *input_text* to the on-disk corpus and return the stored string
    that scores highest under the BERT similarity heuristic.

    Returns a dict with "prediction" (the selected string) and "num_strings"
    (count of distinct stored strings). Raises OSError on file/lock failures.
    """
    # Append under an exclusive advisory lock so concurrent workers cannot
    # interleave writes; release the lock even if the write raises.
    with open('strings.txt', 'a') as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        try:
            f.write(input_text + '\n')
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)
    # Re-read the whole corpus, dedupe, and fix ONE ordering up front so the
    # argmax index below maps back to the correct string. (The original
    # called list(strings) twice and relied on stable set iteration.)
    with open('strings.txt', 'r') as f:
        all_strings = sorted(set(f.read().splitlines()))
    # Batch-tokenize every stored string. `padding=True` pads to the longest
    # sequence in the batch and replaces the deprecated/removed
    # `pad_to_max_length` flag; `truncation=True` prevents a crash on inputs
    # longer than 512 tokens.
    encoded = tokenizer.batch_encode_plus(all_strings, max_length=512,
                                          padding=True, truncation=True,
                                          return_tensors='pt')
    # Run the encoder without gradient tracking; pass the attention mask so
    # padding tokens do not pollute the contextual representations.
    with torch.no_grad():
        outputs = model(encoded['input_ids'],
                        attention_mask=encoded['attention_mask'])
    logits = outputs[0]  # last hidden states, shape (batch, seq_len, hidden)
    # NOTE(review): despite the stated intent ("most similar to the new input
    # string"), this scores each stored string by the cosine similarity of its
    # OWN first ([CLS]) and last token embeddings -- nothing is compared
    # against input_text. Preserved as-is; confirm the intended metric with
    # the author before changing it.
    similarity_scores = torch.nn.functional.cosine_similarity(
        logits[:, 0, :], logits[:, -1, :], dim=1)
    _, prediction_index = torch.max(similarity_scores, dim=0)
    # .item() converts the 0-dim index tensor to a plain int for list indexing.
    prediction = all_strings[prediction_index.item()]
    return {"prediction": prediction, "num_strings": len(all_strings)}
# Here you can do things such as load your models
@app.get("/")
def read_root(input_text: str):
    """HTTP entry point: log the request, run `predict`, and return its result.

    On any failure the exception is logged with its traceback and an
    {"error": <message>} payload is returned instead of a 500 response.
    """
    # Lazy %-style args so the string is only formatted if the level is enabled.
    logging.info("Received request with input_text: %s", input_text)
    try:
        result = predict(input_text)
    except Exception as e:
        # logging.exception records the full traceback, not just the message
        # (also fixes the "occured" typo in the original log line).
        logging.exception("An error occurred: %s", e)
        return {"error": str(e)}
    logging.info("Prediction made: %s", result)
    return result