streetyogi commited on
Commit
c38c542
·
1 Parent(s): 0e3f8e5

Update inference_server.py

Browse files
Files changed (1) hide show
  1. inference_server.py +19 -11
inference_server.py CHANGED
@@ -1,22 +1,30 @@
 
1
  import logging
2
- from sklearn.linear_model import SGDClassifier
 
 
3
  import uvicorn
4
  from fastapi import FastAPI
5
 
6
  app = FastAPI()
7
 
 
 
8
  def predict(input_text: str):
9
- data = [[ord(c) for c in input_text]] # Convert the string to a list of ASCII values
10
- model = train(data)
11
- # Make a prediction
12
- prediction = model.predict([[ord(c) for c in 'abc']]) # Convert the input string to a list of ASCII values
13
- return {"prediction": prediction}
 
 
 
14
 
15
- def train(X):
16
- model = SGDClassifier()
17
- model.fit(X, X) # In this case, we are using the input data as the labels
18
- return model
19
 
 
 
20
  # Here you can do things such as load your models
21
 
22
  @app.get("/")
@@ -25,7 +33,7 @@ def read_root(input_text):
25
  try:
26
  result = predict(input_text)
27
  logging.info("Prediction made: %s", result)
28
- return {"result": 1}
29
  except Exception as e:
30
  logging.error("An error occured: %s", e)
31
  return {"error": str(e)}
 
1
+ import pickle
2
  import logging
3
+ from sklearn.feature_extraction.text import TfidVectorizer
4
+ from sklearn.pipeline import Pipeline
5
+ from sklearn.native_bayes import MultinomialNB
6
  import uvicorn
7
  from fastapi import FastAPI
8
 
9
  app = FastAPI()
10
 
11
+ strings = set() # Set to store all input strings
12
+
13
  def predict(input_text: str):
14
+ # Add the new input string to the set of strings
15
+ strings.add(input_text)
16
+ # Train a new model using all strings in the set
17
+ model = Pipeline([
18
+ ('vectorizer', TfidVectorizer()),
19
+ ('classifier', MultinomialNB())
20
+ ])
21
+ model.fit(list(strings), list(strings))
22
 
23
+ # Make a prediction on the new input string
24
+ prediction = model.predict([input_text])[0]
 
 
25
 
26
+ return {"prediction": prediction}
27
+
28
  # Here you can do things such as load your models
29
 
30
  @app.get("/")
 
33
  try:
34
  result = predict(input_text)
35
  logging.info("Prediction made: %s", result)
36
+ return result
37
  except Exception as e:
38
  logging.error("An error occured: %s", e)
39
  return {"error": str(e)}