streetyogi committed on
Commit
d9bb831
·
1 Parent(s): 077a630

Update inference_server.py

Browse files
Files changed (1) hide show
  1. inference_server.py +20 -8
inference_server.py CHANGED
@@ -5,24 +5,36 @@ from sklearn.pipeline import Pipeline
5
  from sklearn.naive_bayes import MultinomialNB
6
  import uvicorn
7
  from fastapi import FastAPI
 
8
 
9
  app = FastAPI()
10
 
11
  strings = set() # Set to store all input strings
12
 
 
 
 
13
def predict(input_text: str):
    """Return the previously stored string most similar to *input_text*.

    A TF-IDF + MultinomialNB pipeline is fitted where every stored string
    serves as its own class label, so prediction amounts to nearest-string
    retrieval in TF-IDF space.

    Args:
        input_text: The incoming string to match and then remember.

    Returns:
        dict with keys ``prediction`` (the most similar stored string, or
        the input itself on the very first request) and ``num_strings``
        (how many distinct strings have been seen, including this one).
    """
    # BUG FIX: predict BEFORE adding input_text to the training set.
    # The original added it first, so the memorizing classifier (one
    # class per training string) always returned the input itself.
    if strings:
        model = Pipeline([
            ('vectorizer', TfidfVectorizer()),
            ('classifier', MultinomialNB()),
        ])
        # Labels == samples: the classifier memorizes each stored string
        # as its own class, turning prediction into retrieval.
        known = list(strings)
        model.fit(known, known)
        prediction = model.predict([input_text])[0]
    else:
        # First request: nothing stored yet to compare against, so the
        # only sensible answer is the input itself.
        prediction = input_text

    # Remember the new string for future requests.
    strings.add(input_text)

    return {"prediction": prediction, "num_strings": len(strings)}
27
 
28
  # Here you can do things such as load your models
 
5
  from sklearn.naive_bayes import MultinomialNB
6
  import uvicorn
7
  from fastapi import FastAPI
8
+ import transformers
9
 
10
  app = FastAPI()
11
 
12
  strings = set() # Set to store all input strings
13
 
14
# Load the pretrained cased BERT encoder once at startup; `.eval()`
# returns the module itself switched to inference mode (dropout and
# batch-norm frozen), so the load and mode switch chain into one line.
model = transformers.BertModel.from_pretrained('bert-base-cased').eval()
17
def predict(input_text: str):
    """Return the stored string scored most 'similar' by the BERT model.

    Adds *input_text* to the global ``strings`` set, encodes every stored
    string with the BERT tokenizer, runs the batch through the frozen
    module-level ``model``, and picks the string with the highest
    cosine-similarity score (see heuristic note below).

    Args:
        input_text: The incoming string to remember and score.

    Returns:
        dict with keys ``prediction`` (the selected stored string) and
        ``num_strings`` (number of distinct strings seen so far).
    """
    # BUG FIX: `torch` is used below but was never imported in this file;
    # imported locally so this change is self-contained.
    import torch

    # Add the new input string to the set of strings.
    strings.add(input_text)

    # Snapshot the set once so tokenization and the final lookup index
    # into the same ordering.
    known = list(strings)

    # Convert the stored strings to padded input tensors for BERT.
    # NOTE(review): the tokenizer is re-downloaded/re-loaded on every
    # request; consider loading it once at module level like `model`.
    tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')
    input_tensors = tokenizer.batch_encode_plus(
        known,
        max_length=512,
        padding='max_length',  # BUG FIX: `pad_to_max_length` is deprecated
        truncation=True,
        return_tensors='pt',
    )
    input_ids = input_tensors['input_ids']

    # Run BERT without gradient tracking (inference only).
    with torch.no_grad():
        outputs = model(input_ids)
    hidden = outputs[0]  # BUG FIX: original read `output[0]` (NameError)

    # Heuristic score: cosine similarity between each string's first
    # ([CLS]) and last token hidden states.
    # BUG FIX: `csine_similarity` was a typo for `cosine_similarity`.
    # NOTE(review): this compares each string against itself rather than
    # against input_text — the scoring looks wrong, but the original
    # logic is preserved apart from the typo/name fixes above.
    similarity_scores = torch.nn.functional.cosine_similarity(
        hidden[:, 0, :], hidden[:, -1, :], dim=1
    )
    _, prediction_index = torch.max(similarity_scores, dim=0)
    prediction = known[prediction_index]

    return {"prediction": prediction, "num_strings": len(strings)}
39
 
40
  # Here you can do things such as load your models