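# Residue from the web file-viewer page this file was copied from, kept as
# comments so the module remains valid Python:
# me / inference_server.py
# streetyogi's picture
# Update inference_server.py
# 420d681
# raw
# history blame
# 690 Bytes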
from sklearn.linear_model import SGDClassifier
import uvicorn
from fastapi import FastAPI
# The FastAPI application object; route handlers are attached to it with
# decorators such as @app.get(...).
app = FastAPI()
def predict(input_text: str):
    """Fit a throwaway classifier on *input_text* and predict labels for it.

    Each character becomes one sample with a single feature, its code point
    (``ord(c)``). ``train`` uses that same value as the sample's label, and the
    fitted model is then asked to predict labels for those same samples.

    Args:
        input_text: Text to train on and predict for.

    Returns:
        ``{"prediction": [...]}`` holding one predicted label per character,
        as plain Python values so the result can be JSON-encoded.

    Raises:
        ValueError: if *input_text* is empty or has fewer than two distinct
            characters (a classifier needs at least two classes to fit).
    """
    if not input_text:
        raise ValueError("input_text must be a non-empty string")
    if len(set(input_text)) < 2:
        raise ValueError(
            "input_text must contain at least two distinct characters"
        )

    # One sample per character (len(input_text) rows x 1 feature): the
    # classifier needs several samples and exactly one label per sample.
    samples = [[ord(c)] for c in input_text]
    model = train(samples)

    # Query with the caller's input; training and query data must have the
    # same number of features per sample.
    prediction = model.predict(samples)
    # model.predict returns a numpy array; convert it to a plain list so the
    # response can be JSON-encoded.
    return {"prediction": prediction.tolist()}
def train(X, y=None):
    """Fit and return an ``SGDClassifier`` on *X*.

    Args:
        X: Training samples, a sequence of feature rows.
        y: Optional 1-D labels, one per row of *X*. When omitted, the input
            data itself is used as the labels. This requires every row to hold
            exactly one feature, and that value becomes the row's label.

    Returns:
        The fitted ``SGDClassifier``.

    Raises:
        ValueError: if *y* is omitted and some row of *X* does not have
            exactly one feature.
    """
    if y is None:
        # Classifiers expect one label per sample (a 1-D array). Using X
        # itself as the labels is only meaningful when each row has a single
        # feature, so flatten that case explicitly and reject the rest.
        if any(len(row) != 1 for row in X):
            raise ValueError(
                "train() without labels needs exactly one feature per sample"
            )
        y = [row[0] for row in X]
    model = SGDClassifier()
    model.fit(X, y)
    return model
# Here you can do things such as load your models
@app.get("/")
def read_root(input_text):
    """Handle ``GET /`` by greeting ``input_text``.

    ``input_text`` is not a path parameter, so FastAPI takes it from the query
    string (e.g. ``/?input_text=Ada``). It has no default, so it is required.

    NOTE(review): the value returned is a one-element *set*, not a dict. FastAPI
    encodes it as a JSON array such as ``["Hello Ada!"]``. A dict was probably
    intended; it is left as-is so the response shape does not change.
    """
    greeting = f"Hello {input_text}!"
    return {greeting}