# inference_server.py
# NOTE(review): this file was scraped from a Hugging Face file-viewer page
# (user "streetyogi", commit c38c542); the viewer chrome ("raw / history /
# blame", size) has been replaced by this comment so the module parses.
import logging
import pickle

import uvicorn
from fastapi import FastAPI
# fixed typos: the class is TfidfVectorizer (not TfidVectorizer) and the
# module is sklearn.naive_bayes (not sklearn.native_bayes).
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
app = FastAPI()
strings = set() # Set to store all input strings
def predict(input_text: str):
# Add the new input string to the set of strings
strings.add(input_text)
# Train a new model using all strings in the set
model = Pipeline([
('vectorizer', TfidVectorizer()),
('classifier', MultinomialNB())
])
model.fit(list(strings), list(strings))
# Make a prediction on the new input string
prediction = model.predict([input_text])[0]
return {"prediction": prediction}
# Startup hook placeholder: pre-trained models could be loaded here instead
# of retraining on every request (see predict above).
@app.get("/")
def read_root(input_text):
logging.info("Received request with input_text: %s", input_text)
try:
result = predict(input_text)
logging.info("Prediction made: %s", result)
return result
except Exception as e:
logging.error("An error occured: %s", e)
return {"error": str(e)}